1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef GEMMLOWP_META_LEGACY_SINGLE_THREAD_GEMM_H_
16 #define GEMMLOWP_META_LEGACY_SINGLE_THREAD_GEMM_H_
17
18 #include "../internal/common.h"
19
20 #ifdef GEMMLOWP_NEON
21
22 #include "quantized_mul_kernels.h"
23 #include "single_thread_gemm.h"
24 #include "streams.h"
25
26 namespace gemmlowp {
27 namespace meta {
28
gemm_q8_strided(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t result_offset,std::int32_t multiplicative_offset,std::int32_t shift,std::uint8_t * result,std::int32_t result_stride)29 void gemm_q8_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
30 const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
31 std::int32_t k, std::int32_t lhs_offset,
32 std::int32_t rhs_offset, std::int32_t result_offset,
33 std::int32_t multiplicative_offset, std::int32_t shift,
34 std::uint8_t* result, std::int32_t result_stride) {
35 #ifdef DEBUG
36 #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
37 std::cout << "Legacy::GemmQ8." << std::endl;
38 #endif
39 #endif
40 typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum,
41 RowMajorWithSum, QuantizedStaticPreprocessed, RowMajor>
42 Params;
43 Params params;
44
45 params.m = m;
46 params.n = n;
47 params.k = k;
48
49 params.lhs = lhs;
50 params.rhs = rhs;
51 params.result = result;
52 params.scratch = scratch;
53
54 params.left_stream.count = k;
55 params.left_stream.stride = k;
56 params.left_stream.multiplicative_sum_offset = rhs_offset;
57 params.left_stream.additive_sum_offset =
58 result_offset + k * lhs_offset * rhs_offset;
59
60 params.right_stream.count = k;
61 params.right_stream.stride = k;
62 params.right_stream.multiplicative_sum_offset = lhs_offset;
63 params.right_stream.additive_sum_offset = 0;
64
65 params.fused_kernel.kernel.multiplicative_offset = multiplicative_offset;
66 params.fused_kernel.kernel.rounding_offset = (1 << (shift - 1));
67 params.fused_kernel.kernel.shift = -shift;
68 params.fused_kernel.kernel.count = k;
69 params.fused_kernel.output_stream.stride = result_stride;
70
71 Gemm<GemmExecutorPackRHS, Params, 2, 4, 8>(params);
72 }
73
gemv_q8(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t result_offset,std::int32_t multiplicative_offset,std::int32_t shift,std::uint8_t * result)74 void gemv_q8(std::uint8_t* scratch, const std::uint8_t* lhs,
75 const std::uint8_t* rhs, std::int32_t n, std::int32_t k,
76 std::int32_t lhs_offset, std::int32_t rhs_offset,
77 std::int32_t result_offset, std::int32_t multiplicative_offset,
78 std::int32_t shift, std::uint8_t* result) {
79 #ifdef DEBUG
80 #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
81 std::cout << "Legacy::GemvQ8." << std::endl;
82 #endif
83 #endif
84 typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum,
85 RowMajorWithSum, QuantizedStaticPreprocessed, RowMajor>
86 Params;
87 Params params;
88
89 params.m = 1;
90 params.n = n;
91 params.k = k;
92
93 params.lhs = lhs;
94 params.rhs = rhs;
95 params.result = result;
96 params.scratch = scratch;
97
98 params.left_stream.count = k;
99 params.left_stream.stride = k;
100 params.left_stream.multiplicative_sum_offset = rhs_offset;
101 params.left_stream.additive_sum_offset =
102 result_offset + k * lhs_offset * rhs_offset;
103
104 params.right_stream.count = k;
105 params.right_stream.stride = k;
106 params.right_stream.multiplicative_sum_offset = lhs_offset;
107 params.right_stream.additive_sum_offset = 0;
108
109 params.fused_kernel.kernel.multiplicative_offset = multiplicative_offset;
110 params.fused_kernel.kernel.rounding_offset = (1 << (shift - 1));
111 params.fused_kernel.kernel.shift = -shift;
112 params.fused_kernel.kernel.count = k;
113 params.fused_kernel.output_stream.stride = n;
114
115 if (k < 1536) {
116 Gemm<GemmExecutorPackLHS, Params, 1, 8, 8>(params);
117 } else {
118 Gemm<GemmExecutorPackLHS, Params, 2, 4, 8>(params);
119 }
120 }
121
gemm_i32_strided(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t * result,std::int32_t result_stride)122 void gemm_i32_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
123 const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
124 std::int32_t k, std::int32_t lhs_offset,
125 std::int32_t rhs_offset, std::int32_t* result,
126 std::int32_t result_stride) {
127 #ifdef DEBUG
128 #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
129 std::cout << "Legacy::GemmI32." << std::endl;
130 #endif
131 #endif
132 typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum,
133 RowMajorWithSum, QuantizedStaticPreprocessedAsInt32,
134 RowMajor>
135 Params;
136 Params params;
137
138 params.m = m;
139 params.n = n;
140 params.k = k;
141
142 params.lhs = lhs;
143 params.rhs = rhs;
144 params.result = result;
145 params.scratch = scratch;
146
147 params.left_stream.count = k;
148 params.left_stream.stride = k;
149 params.left_stream.multiplicative_sum_offset = rhs_offset;
150 params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset;
151
152 params.right_stream.count = k;
153 params.right_stream.stride = k;
154 params.right_stream.multiplicative_sum_offset = lhs_offset;
155 params.right_stream.additive_sum_offset = 0;
156
157 params.fused_kernel.kernel.count = k;
158 params.fused_kernel.output_stream.stride = result_stride * 4;
159
160 Gemm<GemmExecutorPackRHS, Params, 2, 4, 8>(params);
161 }
162
gemv_i32(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t * result)163 void gemv_i32(std::uint8_t* scratch, const std::uint8_t* lhs,
164 const std::uint8_t* rhs, std::int32_t n, std::int32_t k,
165 std::int32_t lhs_offset, std::int32_t rhs_offset,
166 std::int32_t* result) {
167 #ifdef DEBUG
168 #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
169 std::cout << "Legacy::GemvI32." << std::endl;
170 #endif
171 #endif
172 typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum,
173 RowMajorWithSum, QuantizedStaticPreprocessedAsInt32,
174 RowMajor>
175 Params;
176 Params params;
177
178 params.m = 1;
179 params.n = n;
180 params.k = k;
181
182 params.lhs = lhs;
183 params.rhs = rhs;
184 params.result = result;
185 params.scratch = scratch;
186
187 params.left_stream.count = k;
188 params.left_stream.stride = k;
189 params.left_stream.multiplicative_sum_offset = rhs_offset;
190 params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset;
191
192 params.right_stream.count = k;
193 params.right_stream.stride = k;
194 params.right_stream.multiplicative_sum_offset = lhs_offset;
195 params.right_stream.additive_sum_offset = 0;
196
197 params.fused_kernel.kernel.count = k;
198 params.fused_kernel.output_stream.stride = 0;
199
200 if (k < 1664) {
201 Gemm<GemmExecutorPackLHS, Params, 1, 8, 8>(params);
202 } else {
203 Gemm<GemmExecutorPackLHS, Params, 1, 6, 8>(params);
204 }
205 }
206
gemm_f_strided(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,float * result,std::int32_t result_stride)207 void gemm_f_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
208 const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
209 std::int32_t k, std::int32_t lhs_offset,
210 std::int32_t rhs_offset, float result_offset, float* result,
211 std::int32_t result_stride) {
212 #ifdef DEBUG
213 #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
214 std::cout << "Legacy::GemmF." << std::endl;
215 #endif
216 #endif
217 typedef GemmParams<std::uint8_t, float, RowMajorWithSum, RowMajorWithSum,
218 QuantizedStaticPreprocessedAsFloat, RowMajor>
219 Params;
220 Params params;
221
222 params.m = m;
223 params.n = n;
224 params.k = k;
225
226 params.lhs = lhs;
227 params.rhs = rhs;
228 params.result = result;
229 params.scratch = scratch;
230
231 params.left_stream.count = k;
232 params.left_stream.stride = k;
233 params.left_stream.multiplicative_sum_offset = rhs_offset;
234 params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset;
235
236 params.right_stream.count = k;
237 params.right_stream.stride = k;
238 params.right_stream.multiplicative_sum_offset = lhs_offset;
239 params.right_stream.additive_sum_offset = 0;
240
241 params.fused_kernel.kernel.count = k;
242 params.fused_kernel.kernel.scale = result_offset;
243 params.fused_kernel.output_stream.stride = result_stride * 4;
244
245 Gemm<GemmExecutorPackRHS, Params, 2, 4, 8>(params);
246 }
247
gemv_f(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,float * result)248 void gemv_f(std::uint8_t* scratch, const std::uint8_t* lhs,
249 const std::uint8_t* rhs, std::int32_t n, std::int32_t k,
250 std::int32_t lhs_offset, std::int32_t rhs_offset,
251 float result_offset, float* result) {
252 #ifdef DEBUG
253 #ifdef DEBUG_METAGEMM_LEGACY_VERBOSE
254 std::cout << "Legacy::GemvF." << std::endl;
255 #endif
256 #endif
257 typedef GemmParams<std::uint8_t, float, RowMajorWithSum, RowMajorWithSum,
258 QuantizedStaticPreprocessedAsFloat, RowMajor>
259 Params;
260 Params params;
261
262 params.m = 1;
263 params.n = n;
264 params.k = k;
265
266 params.lhs = lhs;
267 params.rhs = rhs;
268 params.result = result;
269 params.scratch = scratch;
270
271 params.left_stream.count = k;
272 params.left_stream.stride = k;
273 params.left_stream.multiplicative_sum_offset = rhs_offset;
274 params.left_stream.additive_sum_offset = k * lhs_offset * rhs_offset;
275
276 params.right_stream.count = k;
277 params.right_stream.stride = k;
278 params.right_stream.multiplicative_sum_offset = lhs_offset;
279 params.right_stream.additive_sum_offset = 0;
280
281 params.fused_kernel.kernel.count = k;
282 params.fused_kernel.kernel.scale = result_offset;
283 params.fused_kernel.output_stream.stride = 0;
284
285 if (k < 1664) {
286 Gemm<GemmExecutorPackLHS, Params, 1, 8, 8>(params);
287 } else {
288 Gemm<GemmExecutorPackLHS, Params, 1, 6, 8>(params);
289 }
290 }
291
292 } // namespace meta
293 } // namespace gemmlowp
294
295 #else
296 #warning "Meta gemm fast-path requires GEMMLOWP_NEON_(32|64)!"
297 #endif
298
299 #endif // GEMMLOWP_META_LEGACY_SINGLE_THREAD_GEMM_H_
300