1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "eight_bit_int_gemm.h"
16
17 #include <memory>
18
19 // gemmlowp symbols should have hidden visibility.
20 // currently this is ensured in the build system by
21 // passing -finlines-visibility-hidden. TODO: it would be
22 // safer to hardcode it here with some #pragma's.
23 #include "../public/gemmlowp.h"
24
25 // Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
26 // code. This code path consists of a number of meta-programmed, automatically
27 // generated GEMM kernels that are suitable for some sizes of input matrices.
28 // Due to the fact that the generated code relies heavily on loop unrolling,
29 // inling and currying of runtime parameters the size of the generated binary
30 // is quite significant (approx. 200kb) which might be prohibitive in
31 // low-memory situations.
32
33 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
34 #include "../meta/multi_thread_gemm.h"
35 #endif
36
37 namespace gemmlowp {
38 namespace eight_bit_int_gemm {
39 namespace {
40
// To be used as template parameter for GlobalLock.
// GlobalLock<EightBitIntGemmLockId> is the global lock
// on EightBitIntGemm entry points, protecting
// EightBitIntGemm's global state.
struct EightBitIntGemmLockId;

// Global state: consists of one global GemmContext instance.
// Lazily created by GetOrCreateGlobalContext() and destroyed by
// DestroyGlobalContext(); only touched by the public entry points,
// which all hold AutoGlobalLock<EightBitIntGemmLockId> first.
GemmContext* global_context;
49
GetOrCreateGlobalContext()50 GemmContext* GetOrCreateGlobalContext() {
51 if (!global_context) {
52 global_context = new GemmContext;
53 }
54 return global_context;
55 }
56
DestroyGlobalContext()57 void DestroyGlobalContext() {
58 delete global_context;
59 global_context = nullptr;
60 }
61
62 template <bool transpose_a, bool transpose_b, bool transpose_c>
EightBitIntGemmImpl(GemmContext * context,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::uint8_t * c,std::int32_t c_offset,std::int32_t c_mult_int,std::int32_t c_shift,int ldc,BitDepthSetting bit_depth)63 void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
64 const std::uint8_t* a, std::int32_t a_offset, int lda,
65 const std::uint8_t* b, std::int32_t b_offset, int ldb,
66 std::uint8_t* c, std::int32_t c_offset,
67 std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
68 BitDepthSetting bit_depth) {
69 const int lhs_offset = a_offset;
70 const int rhs_offset = b_offset;
71 const int result_offset = c_offset;
72 const int result_mult_int = c_mult_int;
73 const int result_shift = c_shift;
74
75 static const MapOrder ResultOrder =
76 transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
77 static const MapOrder LhsOrder =
78 transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
79 static const MapOrder RhsOrder =
80 transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;
81
82 MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
83 MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
84 MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);
85
86 switch (bit_depth) {
87 #define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
88 case BitDepthSetting::BIT_DEPTH_SETTING: \
89 Gemm<std::uint8_t, BIT_DEPTH_PARAMS>( \
90 context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
91 result_mult_int, result_shift); \
92 return;
93 GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
94 GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
95 default:
96 abort();
97 #undef GEMMLOWP_HANDLE_BIT_DEPTH
98 }
99 }
100
101 template <bool transpose_a, bool transpose_b, bool transpose_c>
EightBitIntGemmInt32Impl(GemmContext * context,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::int32_t * c,int ldc,BitDepthSetting bit_depth)102 void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
103 const std::uint8_t* a, std::int32_t a_offset,
104 int lda, const std::uint8_t* b,
105 std::int32_t b_offset, int ldb, std::int32_t* c,
106 int ldc, BitDepthSetting bit_depth) {
107 const int lhs_offset = a_offset;
108 const int rhs_offset = b_offset;
109
110 static const MapOrder ResultOrder =
111 transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
112 static const MapOrder LhsOrder =
113 transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
114 static const MapOrder RhsOrder =
115 transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;
116
117 MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
118 MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
119 MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);
120
121 auto empty_pipeline = std::make_tuple();
122
123 switch (bit_depth) {
124 #define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
125 case BitDepthSetting::BIT_DEPTH_SETTING: \
126 GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>( \
127 context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
128 return;
129 GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
130 GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
131 default:
132 abort();
133 #undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
134 }
135 }
136
// Lazily grown scratch byte buffer. The allocation only ever grows:
// AssureSize reallocates when the request exceeds the current capacity
// and otherwise keeps the existing buffer. Note that a reallocation does
// NOT preserve previous contents.
class Scratch {
 public:
  Scratch() : buffer_(), size_(0) {}

  // Guarantees at least required_size bytes of backing storage.
  void AssureSize(std::int32_t required_size) {
    if (required_size > size_) {
      buffer_.reset(new std::uint8_t[required_size]);
      size_ = required_size;
    }
  }

  // Releases the backing storage and resets the tracked capacity.
  void Clear() {
    buffer_.reset(nullptr);
    size_ = 0;
  }

  std::uint8_t* buffer() { return buffer_.get(); }

 private:
  std::unique_ptr<std::uint8_t[]> buffer_;  // owned storage, may be null
  std::int32_t size_;                       // current capacity in bytes
};
160
161 Scratch* global_scratch = nullptr;
162
GetOrCreateGlobalScratch()163 Scratch* GetOrCreateGlobalScratch() {
164 if (global_scratch == nullptr) {
165 global_scratch = new Scratch();
166 }
167 return global_scratch;
168 }
169
DestroyGlobalScratch()170 void DestroyGlobalScratch() {
171 delete global_scratch;
172 global_scratch = nullptr;
173 }
174
175 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
176
// True when the operand can be consumed as densely-packed row-major data:
// either it is row major (transposed) with no stride padding, or it is a
// single-row vector, which is simultaneously row and column major.
bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
  const bool packed_row_major = transpose && stride == cols;
  const bool one_row_vector = (rows == 1);
  return packed_row_major || one_row_vector;
}
190
// True when the operand can be consumed as densely-packed column-major
// data: either it is column major (untransposed) with no stride padding,
// or it is a single-column vector, which is both row and column major.
bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
  const bool packed_col_major = !transpose && stride == rows;
  const bool one_col_vector = (cols == 1);
  return packed_col_major || one_col_vector;
}
204
CanHandleMetaFastpath(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,int lda,int ldb,int ldc,BitDepthSetting depth_setting)205 bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
206 int m, int n, int k, int lda, int ldb, int ldc,
207 BitDepthSetting depth_setting) {
208 // Meta fastpath only supports 8bit x 8bit and k up to 2048.
209 if (depth_setting != BitDepthSetting::A8B8 || k > 2048) {
210 return false;
211 }
212
213 // The first operand needs to be a row major matrix or a vector.
214 if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
215 return false;
216 }
217
218 // The second operand needs to be a column major matrix or a vector.
219 if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
220 return false;
221 }
222
223 // The result can either be a row major matrix, a column major matrix or
224 // a vector.
225 if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
226 return true;
227 }
228
229 if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
230 return true;
231 }
232
233 return false;
234 }
235
236 // Assure enough scratch memory is allocated and run the fast path gemm.
MetaGemmQuantized8Bit(GemmContext * context,const std::uint8_t * lhs,const std::uint8_t * rhs,int m,int n,int k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplicative_offset,std::int32_t shift,bool result_transpose,std::int32_t result_stride,std::uint8_t * result)237 void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
238 const std::uint8_t* rhs, int m, int n, int k,
239 std::int32_t lhs_offset, std::int32_t rhs_offset,
240 std::int32_t sum_offset,
241 std::int32_t multiplicative_offset,
242 std::int32_t shift, bool result_transpose,
243 std::int32_t result_stride, std::uint8_t* result) {
244 Scratch* scratch = GetOrCreateGlobalScratch();
245 if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
246 scratch->AssureSize(
247 meta::gemm_q8_scratch(m, n, k, context->max_num_threads()));
248 meta::multi_thread_gemm_q8(
249 context->workers_pool(), context->max_num_threads(), scratch->buffer(),
250 lhs, rhs, m, n, k, lhs_offset, rhs_offset, sum_offset,
251 multiplicative_offset, shift, result);
252 } else {
253 scratch->AssureSize(
254 meta::gemm_q8_scratch(n, m, k, context->max_num_threads()));
255 meta::multi_thread_gemm_q8(
256 context->workers_pool(), context->max_num_threads(), scratch->buffer(),
257 rhs, lhs, n, m, k, rhs_offset, lhs_offset, sum_offset,
258 multiplicative_offset, shift, result);
259 }
260 }
261
262 // Assure enough scratch memory is allocated and run the 8bit to float fast
263 // path gemm.
MetaGemmFloat(GemmContext * context,const std::uint8_t * lhs,const std::uint8_t * rhs,int m,int n,int k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,bool result_transpose,std::int32_t result_stride,float * result)264 void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
265 const std::uint8_t* rhs, int m, int n, int k,
266 std::int32_t lhs_offset, std::int32_t rhs_offset,
267 float result_offset, bool result_transpose,
268 std::int32_t result_stride, float* result) {
269 Scratch* scratch = GetOrCreateGlobalScratch();
270 if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
271 scratch->AssureSize(
272 meta::gemm_f_scratch(m, n, k, context->max_num_threads()));
273 meta::multi_thread_gemm_f(
274 context->workers_pool(), context->max_num_threads(), scratch->buffer(),
275 lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset, result);
276 } else {
277 scratch->AssureSize(
278 meta::gemm_f_scratch(n, m, k, context->max_num_threads()));
279 meta::multi_thread_gemm_f(
280 context->workers_pool(), context->max_num_threads(), scratch->buffer(),
281 rhs, lhs, n, m, k, rhs_offset, lhs_offset, result_offset, result);
282 }
283 }
284
285 #endif
286
287 } // end anonymous namespace
288
289 // Public interface entry points
290
EightBitIntGemm(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::uint8_t * c,std::int32_t c_offset,std::int32_t c_mult_int,std::int32_t c_shift,int ldc,BitDepthSetting bit_depth)291 void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
292 int m, int n, int k, const std::uint8_t* a,
293 std::int32_t a_offset, int lda, const std::uint8_t* b,
294 std::int32_t b_offset, int ldb, std::uint8_t* c,
295 std::int32_t c_offset, std::int32_t c_mult_int,
296 std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
297 AutoGlobalLock<EightBitIntGemmLockId> lock;
298 GemmContext* context = GetOrCreateGlobalContext();
299
300 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
301 if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
302 ldb, ldc, bit_depth)) {
303 MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
304 c_mult_int, c_shift, transpose_c, ldc, c);
305 return;
306 }
307 #endif
308
309 #define GEMMLOWP_HANDLE_CASE(ta, tb, tc) \
310 if (transpose_a == ta && transpose_b == tb && transpose_c == tc) { \
311 EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b, \
312 b_offset, ldb, c, c_offset, c_mult_int, \
313 c_shift, ldc, bit_depth); \
314 }
315
316 GEMMLOWP_HANDLE_CASE(false, false, false)
317 GEMMLOWP_HANDLE_CASE(false, false, true)
318 GEMMLOWP_HANDLE_CASE(false, true, false)
319 GEMMLOWP_HANDLE_CASE(false, true, true)
320 GEMMLOWP_HANDLE_CASE(true, false, false)
321 GEMMLOWP_HANDLE_CASE(true, false, true)
322 GEMMLOWP_HANDLE_CASE(true, true, false)
323 GEMMLOWP_HANDLE_CASE(true, true, true)
324
325 #undef GEMMLOWP_HANDLE_CASE
326 }
327
// Public entry point: 8-bit GEMM with float output. Computes the int32
// accumulator matrix, then multiplies each accumulator by c_offset to
// produce the float result — c_offset acts here as a multiplicative
// scale, not an additive offset.
void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, std::int32_t lda,
                     const std::uint8_t* b, std::int32_t b_offset,
                     std::int32_t ldb, float* c, float c_offset,
                     std::int32_t ldc, BitDepthSetting bit_depth) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
  // Fast path: meta-generated NEON kernels, when shapes/orders allow.
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                  transpose_c, ldc, c);
    return;
  }
#endif

  // TODO(maciekc): implement a float output stage, get rid of scratch memory.
  // Generic path: run the integer GEMM into a temporary int32 buffer, then
  // convert to float below. The scratch area must cover the whole result
  // including stride padding: m rows of ldc elements when row major
  // (transpose_c), n columns of ldc elements when column major.
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (transpose_c) {
    scratch->AssureSize(m * ldc * sizeof(std::int32_t));
  } else {
    scratch->AssureSize(n * ldc * sizeof(std::int32_t));
  }
  std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());

  // Dispatch the runtime transpose flags to the matching compile-time
  // instantiation; exactly one of the eight cases below matches.
#define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc)                            \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {      \
    EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
                                         b, b_offset, ldb, temp_c, ldc,   \
                                         bit_depth);                      \
  }

  GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_INT32_CASE

  // Convert the int32 accumulators to float, scaling by c_offset. Both
  // loops walk line-by-line with stride ldc, touching only the m*n live
  // elements (padding in the scratch buffer is left as-is).
  if (transpose_c) {
    // Row major.
    for (int i = 0; i < m; ++i) {
      float* dest_row = c + i * ldc;
      std::int32_t* src_row = temp_c + i * ldc;
      for (int j = 0; j < n; ++j) {
        dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
      }
    }
  } else {
    // Column major.
    for (int i = 0; i < n; ++i) {
      float* dest_column = c + i * ldc;
      std::int32_t* src_column = temp_c + i * ldc;
      for (int j = 0; j < m; ++j) {
        dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
      }
    }
  }
}
393
SetMaxNumThreads(int n)394 void SetMaxNumThreads(int n) {
395 AutoGlobalLock<EightBitIntGemmLockId> lock;
396 GemmContext* context = GetOrCreateGlobalContext();
397 context->set_max_num_threads(n);
398 }
399
FreePersistentResources()400 void FreePersistentResources() {
401 AutoGlobalLock<EightBitIntGemmLockId> lock;
402 DestroyGlobalContext();
403 DestroyGlobalScratch();
404 }
405
406 } // namespace eight_bit_int_gemm
407 } // namespace gemmlowp
408