/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cmath>
#include <cstdint>

#include <executorch/backends/cadence/reference/kernels/kernels.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace impl {
namespace reference {
namespace native {

using executorch::aten::Tensor;
using executorch::runtime::getLeadingDims;
using executorch::runtime::KernelRuntimeContext;

// Quantized matmul: multiplies two quantized matrices, accumulating the
// zero-point-adjusted products in a wider type TA before requantizing the
// result into the output type TZ.
template <
    typename TZ,
    typename TA = float,
    bool transposed = false,
    typename TX = TZ,
    typename TY = TZ>
__attribute__((noinline)) void qmatmul(
    TZ* __restrict__ Z,
    int32_t Z_multiplier,
    int32_t Z_shift,
    int32_t Z_zero_point,
    const TX* __restrict__ X,
    int32_t X_zero_point,
    const TY* __restrict__ y,
    int32_t Y_zero_point,
    size_t m,
    size_t n,
    size_t p) {
  // Compute the requantization scale Z_scale from Z_multiplier (a fixed-point
  // multiplier scaled by 2^31) and Z_shift (a power-of-two exponent).
  const float Z_scale =
      Z_multiplier * 1.0 / (1LL << 31) * std::pow(2, Z_shift);
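  // For example (illustrative values, not from any particular model): with
  // Z_multiplier = 1073741824 (2^30) and Z_shift = 1, the scale works out to
  // 2^30 / 2^31 * 2^1 = 1.0.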
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < p; ++j) {
      // Accumulate the zero-point-adjusted products in the wider type TA.
      TA sum = 0;
      for (size_t k = 0; k < n; ++k) {
        if (transposed) {
          sum += (X[i * n + k] - X_zero_point) * (y[j * n + k] - Y_zero_point);
        } else {
          sum += (X[i * n + k] - X_zero_point) * (y[k * p + j] - Y_zero_point);
        }
      }
      // Requantize the accumulator into the output type TZ.
      Z[i * p + j] = kernels::quantize<TZ>(sum, Z_scale, Z_zero_point);
    }
  }
}
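
// A minimal sketch of what qmatmul computes, with made-up values: for
// m = n = p = 2, X = [[1, 2], [3, 4]] and y = [[5, 6], [7, 8]] (both with
// zero point 0, transposed = false), the raw accumulations are
// [[19, 22], [43, 50]]; with Z_multiplier = 1 << 30 and Z_shift = 0
// (i.e. Z_scale = 0.5) and Z_zero_point = 0, the requantized outputs are
// approximately [[10, 11], [22, 25]], up to the exact rounding and
// saturation behavior of kernels::quantize.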

template <typename T>
inline void _typed_quantized_matmul(
    const Tensor& X,
    int64_t X_zero_point,
    const Tensor& Y,
    int64_t Y_zero_point,
    const executorch::aten::optional<Tensor>& bias,
    int64_t out_multiplier,
    int64_t out_shift,
    int64_t out_zero_point,
    bool transposed,
    Tensor& out) {
  size_t batch_size = getLeadingDims(X, X.dim() - 2);
  size_t leading_dim = X.size(X.dim() - 2);
  // When transposed, the output dim is the second-to-last dim of Y.
  size_t out_dim = Y.size(Y.dim() - 1 - transposed);
  size_t in_dim = X.size(X.dim() - 1);
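  // For example (illustrative shapes): with X of shape [2, 4, 8] and Y of
  // shape [2, 8, 16] (transposed = false), batch_size = 2, leading_dim = 4,
  // in_dim = 8 and out_dim = 16; with transposed = true, Y would instead
  // have shape [2, 16, 8].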

  // Note: the bias argument is not used in this reference implementation.
  T* __restrict__ out_data = out.mutable_data_ptr<T>();
  const T* __restrict__ X_data = X.const_data_ptr<T>();
  const T* __restrict__ Y_data = Y.const_data_ptr<T>();
  // Run one quantized matmul per batch slice, accumulating in int32_t.
  for (size_t i = 0; i < batch_size; ++i) {
    const T* x = X_data + i * leading_dim * in_dim;
    const T* y = Y_data + i * in_dim * out_dim;
    T* z = out_data + i * leading_dim * out_dim;
    if (transposed) {
      qmatmul<T, int32_t, true>(
          z,
          static_cast<int32_t>(out_multiplier),
          static_cast<int32_t>(out_shift),
          static_cast<int32_t>(out_zero_point),
          x,
          static_cast<int32_t>(X_zero_point),
          y,
          static_cast<int32_t>(Y_zero_point),
          leading_dim,
          in_dim,
          out_dim);
    } else {
      qmatmul<T, int32_t, false>(
          z,
          static_cast<int32_t>(out_multiplier),
          static_cast<int32_t>(out_shift),
          static_cast<int32_t>(out_zero_point),
          x,
          static_cast<int32_t>(X_zero_point),
          y,
          static_cast<int32_t>(Y_zero_point),
          leading_dim,
          in_dim,
          out_dim);
    }
  }
}

void quantized_matmul_out(
    KernelRuntimeContext& ctx,
    const Tensor& X,
    int64_t X_zero_point,
    const Tensor& Y,
    int64_t Y_zero_point,
    const executorch::aten::optional<Tensor>& bias,
    int64_t out_multiplier,
    int64_t out_shift,
    int64_t out_zero_point,
    bool transposed,
    Tensor& out) {
  // Dispatch on the output dtype; this kernel supports uint8 (Byte) and
  // int8 (Char).
  if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
    _typed_quantized_matmul<uint8_t>(
        X,
        X_zero_point,
        Y,
        Y_zero_point,
        bias,
        out_multiplier,
        out_shift,
        out_zero_point,
        transposed,
        out);
  } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
    _typed_quantized_matmul<int8_t>(
        X,
        X_zero_point,
        Y,
        Y_zero_point,
        bias,
        out_multiplier,
        out_shift,
        out_zero_point,
        transposed,
        out);
  } else {
    ET_CHECK_MSG(
        false,
        "Unhandled out dtype %hhd",
        static_cast<int8_t>(out.scalar_type()));
  }
}

} // namespace native
} // namespace reference
} // namespace impl