1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "eight_bit_int_gemm.h"
16
#include <cstddef>
#include <cstdint>
#include <memory>
18
// gemmlowp symbols should have hidden visibility.
// Currently this is ensured in the build system by
// passing -fvisibility-inlines-hidden. TODO: it would be
// safer to hardcode it here with some #pragma's.
23 #include "../public/gemmlowp.h"
24
// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
// code. This code path consists of a number of meta-programmed, automatically
// generated GEMM kernels that are suitable for some sizes of input matrices.
// Due to the fact that the generated code relies heavily on loop unrolling,
// inlining and currying of runtime parameters, the size of the generated
// binary is quite significant (approx. 200kb) which might be prohibitive in
// low-memory situations.
32
33 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
34 #include "../meta/legacy_multi_thread_gemm.h"
35 #else
36
37 #if defined(GEMMLOWP_USE_META_FASTPATH)
38 #warning "META fast path turned on without NEON!"
39 #endif
40
41 #endif
42
43 namespace gemmlowp {
44 namespace eight_bit_int_gemm {
45 namespace {
46
47 // To be used as template parameter for GlobalLock.
48 // GlobalLock<EightBitIntGemmLockId> is the global lock
49 // on EightBitIntGemm entry points, protecting
50 // EightBitIntGemm's global state.
51 struct EightBitIntGemmLockId;
52
53 // Global state: consists of one global GemmContext instance.
54 GemmContext* global_context;
55
GetOrCreateGlobalContext()56 GemmContext* GetOrCreateGlobalContext() {
57 if (!global_context) {
58 global_context = new GemmContext;
59 }
60 return global_context;
61 }
62
DestroyGlobalContext()63 void DestroyGlobalContext() {
64 delete global_context;
65 global_context = nullptr;
66 }
67
68 template <bool transpose_a, bool transpose_b, bool transpose_c>
EightBitIntGemmImpl(GemmContext * context,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::uint8_t * c,std::int32_t c_offset,std::int32_t c_mult_int,std::int32_t c_shift,int ldc,BitDepthSetting bit_depth)69 void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
70 const std::uint8_t* a, std::int32_t a_offset, int lda,
71 const std::uint8_t* b, std::int32_t b_offset, int ldb,
72 std::uint8_t* c, std::int32_t c_offset,
73 std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
74 BitDepthSetting bit_depth) {
75 const int lhs_offset = a_offset;
76 const int rhs_offset = b_offset;
77 const int result_offset = c_offset;
78 const int result_mult_int = c_mult_int;
79 const int result_shift = c_shift;
80
81 static const MapOrder ResultOrder =
82 transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
83 static const MapOrder LhsOrder =
84 transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
85 static const MapOrder RhsOrder =
86 transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;
87
88 MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
89 MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
90 MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);
91
92 switch (bit_depth) {
93 #define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
94 case BitDepthSetting::BIT_DEPTH_SETTING: \
95 Gemm<std::uint8_t, BIT_DEPTH_PARAMS>( \
96 context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
97 result_mult_int, result_shift); \
98 return;
99 GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
100 GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
101 default:
102 abort();
103 #undef GEMMLOWP_HANDLE_BIT_DEPTH
104 }
105 }
106
107 template <bool transpose_a, bool transpose_b, bool transpose_c>
EightBitIntGemmInt32Impl(GemmContext * context,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::int32_t * c,int ldc,BitDepthSetting bit_depth)108 void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
109 const std::uint8_t* a, std::int32_t a_offset,
110 int lda, const std::uint8_t* b,
111 std::int32_t b_offset, int ldb, std::int32_t* c,
112 int ldc, BitDepthSetting bit_depth) {
113 const int lhs_offset = a_offset;
114 const int rhs_offset = b_offset;
115
116 static const MapOrder ResultOrder =
117 transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
118 static const MapOrder LhsOrder =
119 transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
120 static const MapOrder RhsOrder =
121 transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;
122
123 MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
124 MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
125 MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);
126
127 auto empty_pipeline = std::make_tuple();
128
129 switch (bit_depth) {
130 #define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
131 case BitDepthSetting::BIT_DEPTH_SETTING: \
132 GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>( \
133 context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
134 return;
135 GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
136 GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
137 default:
138 abort();
139 #undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
140 }
141 }
142
// Grow-only byte buffer whose usable region starts on a 32-byte boundary.
//
// AssureSize() reallocates only when more room is needed than is currently
// available; Clear() releases the storage. A default-constructed or cleared
// Scratch has a null buffer().
class Scratch {
 public:
  Scratch() : buffer_(), buffer_32_(nullptr), size_(0) {}

  // Makes sure at least `required_size` bytes are available at buffer().
  // Does nothing when the current buffer is already large enough (which
  // includes every `required_size` <= 0). Otherwise the storage is replaced;
  // the previous contents are not carried over.
  void AssureSize(std::int32_t required_size) {
    if (size_ >= required_size) {
      return;
    }
    // Here required_size > size_ >= 0, so the cast below is value-preserving.
    // The padding is added in size_t rather than std::int32_t because
    // `required_size + 32` would overflow signed 32-bit arithmetic (undefined
    // behaviour) for sizes close to INT32_MAX.
    const std::size_t padded_size =
        static_cast<std::size_t>(required_size) + kAlignment;
    buffer_.reset(new std::uint8_t[padded_size]);

    // Skip forward to the first 32-byte aligned address in the allocation.
    const std::uintptr_t base_address =
        reinterpret_cast<std::uintptr_t>(buffer_.get());
    const std::size_t misalignment =
        static_cast<std::size_t>(base_address % kAlignment);
    const std::size_t padding = (kAlignment - misalignment) % kAlignment;
    buffer_32_ = buffer_.get() + padding;
    assert((reinterpret_cast<std::uintptr_t>(buffer_32_) % kAlignment) == 0);
    size_ = required_size;
  }

  // Releases the storage; buffer() returns nullptr until the next successful
  // AssureSize() call.
  void Clear() {
    buffer_.reset(nullptr);
    buffer_32_ = nullptr;
    size_ = 0;
  }

  // Returns the aligned start of the usable region, or nullptr if nothing has
  // been allocated.
  std::uint8_t* buffer() { return buffer_32_; }

 private:
  // Required alignment, in bytes, of the pointer returned by buffer(); also
  // the number of extra bytes allocated to make that alignment possible.
  static const std::size_t kAlignment = 32;

  std::unique_ptr<std::uint8_t[]> buffer_;  // Owns the raw allocation.
  std::uint8_t* buffer_32_;                 // Aligned pointer into buffer_.
  std::int32_t size_;                       // Usable bytes at buffer_32_.
};
172
173 Scratch* global_scratch = nullptr;
174
GetOrCreateGlobalScratch()175 Scratch* GetOrCreateGlobalScratch() {
176 if (global_scratch == nullptr) {
177 global_scratch = new Scratch();
178 }
179 return global_scratch;
180 }
181
DestroyGlobalScratch()182 void DestroyGlobalScratch() {
183 delete global_scratch;
184 global_scratch = nullptr;
185 }
186
187 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
188
// Returns true when a rows x cols matrix can be treated as row major: either
// it is flagged as transposed (row major) and densely packed, i.e. its stride
// equals the column count, or it has a single row (a vector is both row and
// column major).
bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
  const bool packed_row_major = transpose && (stride == cols);
  const bool single_row = (rows == 1);
  return packed_row_major || single_row;
}
202
// Returns true when a rows x cols matrix can be treated as column major:
// either it is not flagged as transposed (column major) and densely packed,
// i.e. its stride equals the row count, or it has a single column (a vector is
// both row and column major).
bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
  const bool packed_column_major = !transpose && (stride == rows);
  const bool single_column = (cols == 1);
  return packed_column_major || single_column;
}
216
CanHandleMetaFastpath(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,int lda,int ldb,int ldc,BitDepthSetting depth_setting)217 bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
218 int m, int n, int k, int lda, int ldb, int ldc,
219 BitDepthSetting depth_setting) {
220 // Meta fastpath only supports 8bit x 8bit and k between 8 and 2048.
221 if (depth_setting != BitDepthSetting::A8B8 || k < 8 || k > 2048) {
222 return false;
223 }
224
225 // The first operand needs to be a row major matrix or a vector.
226 if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
227 return false;
228 }
229
230 // The second operand needs to be a column major matrix or a vector.
231 if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
232 return false;
233 }
234
235 // The result can either be a row major matrix, a column major matrix or
236 // a vector.
237 if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
238 return true;
239 }
240
241 if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
242 return true;
243 }
244
245 return false;
246 }
247
248 // Assure enough scratch memory is allocated and run the fast path gemm.
MetaGemmQuantized8Bit(GemmContext * context,const std::uint8_t * lhs,const std::uint8_t * rhs,int m,int n,int k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplicative_offset,std::int32_t shift,bool result_transpose,std::int32_t result_stride,std::uint8_t * result)249 void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
250 const std::uint8_t* rhs, int m, int n, int k,
251 std::int32_t lhs_offset, std::int32_t rhs_offset,
252 std::int32_t sum_offset,
253 std::int32_t multiplicative_offset,
254 std::int32_t shift, bool result_transpose,
255 std::int32_t result_stride, std::uint8_t* result) {
256 Scratch* scratch = GetOrCreateGlobalScratch();
257 const std::int32_t max_num_threads = context->max_num_threads();
258 if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
259 scratch->AssureSize(meta::gemm_q8_scratch(m, n, k, max_num_threads));
260 meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
261 scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
262 rhs_offset, sum_offset, multiplicative_offset,
263 shift, result);
264 } else {
265 scratch->AssureSize(meta::gemm_q8_scratch(n, m, k, max_num_threads));
266 meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
267 scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
268 lhs_offset, sum_offset, multiplicative_offset,
269 shift, result);
270 }
271 }
272
273 // Assure enough scratch memory is allocated and run the 8bit to float fast
274 // path gemm.
MetaGemmFloat(GemmContext * context,const std::uint8_t * lhs,const std::uint8_t * rhs,int m,int n,int k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,bool result_transpose,std::int32_t result_stride,float * result)275 void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
276 const std::uint8_t* rhs, int m, int n, int k,
277 std::int32_t lhs_offset, std::int32_t rhs_offset,
278 float result_offset, bool result_transpose,
279 std::int32_t result_stride, float* result) {
280 Scratch* scratch = GetOrCreateGlobalScratch();
281 const std::int32_t max_num_threads = context->max_num_threads();
282 if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
283 scratch->AssureSize(meta::gemm_f_scratch(m, n, k, max_num_threads));
284 meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
285 scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
286 rhs_offset, result_offset, result);
287 } else {
288 scratch->AssureSize(meta::gemm_f_scratch(n, m, k, max_num_threads));
289 meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
290 scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
291 lhs_offset, result_offset, result);
292 }
293 }
294
295 #endif
296
297 } // end anonymous namespace
298
299 // Public interface entry points
300
EightBitIntGemm(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::uint8_t * c,std::int32_t c_offset,std::int32_t c_mult_int,std::int32_t c_shift,int ldc,BitDepthSetting bit_depth)301 void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
302 int m, int n, int k, const std::uint8_t* a,
303 std::int32_t a_offset, int lda, const std::uint8_t* b,
304 std::int32_t b_offset, int ldb, std::uint8_t* c,
305 std::int32_t c_offset, std::int32_t c_mult_int,
306 std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
307 ScopedLock sl(GlobalMutexes::EightBitIntGemm());
308 GemmContext* context = GetOrCreateGlobalContext();
309
310 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
311 if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
312 ldb, ldc, bit_depth)) {
313 MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
314 c_mult_int, c_shift, transpose_c, ldc, c);
315 return;
316 }
317 #endif
318
319 #define GEMMLOWP_HANDLE_CASE(ta, tb, tc) \
320 if (transpose_a == ta && transpose_b == tb && transpose_c == tc) { \
321 EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b, \
322 b_offset, ldb, c, c_offset, c_mult_int, \
323 c_shift, ldc, bit_depth); \
324 }
325
326 GEMMLOWP_HANDLE_CASE(false, false, false)
327 GEMMLOWP_HANDLE_CASE(false, false, true)
328 GEMMLOWP_HANDLE_CASE(false, true, false)
329 GEMMLOWP_HANDLE_CASE(false, true, true)
330 GEMMLOWP_HANDLE_CASE(true, false, false)
331 GEMMLOWP_HANDLE_CASE(true, false, true)
332 GEMMLOWP_HANDLE_CASE(true, true, false)
333 GEMMLOWP_HANDLE_CASE(true, true, true)
334
335 #undef GEMMLOWP_HANDLE_CASE
336 }
337
EightBitIntGemm(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,std::int32_t lda,const std::uint8_t * b,std::int32_t b_offset,std::int32_t ldb,float * c,float c_offset,std::int32_t ldc,BitDepthSetting bit_depth)338 void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
339 int m, int n, int k, const std::uint8_t* a,
340 std::int32_t a_offset, std::int32_t lda,
341 const std::uint8_t* b, std::int32_t b_offset,
342 std::int32_t ldb, float* c, float c_offset,
343 std::int32_t ldc, BitDepthSetting bit_depth) {
344 ScopedLock sl(GlobalMutexes::EightBitIntGemm());
345 GemmContext* context = GetOrCreateGlobalContext();
346
347 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
348 if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
349 ldb, ldc, bit_depth)) {
350 MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
351 transpose_c, ldc, c);
352 return;
353 }
354 #endif
355
356 // TODO(maciekc): implement a float output stage, get rid of scratch memory.
357 Scratch* scratch = GetOrCreateGlobalScratch();
358 if (transpose_c) {
359 scratch->AssureSize(m * ldc * sizeof(std::int32_t));
360 } else {
361 scratch->AssureSize(n * ldc * sizeof(std::int32_t));
362 }
363 std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());
364
365 #define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc) \
366 if (transpose_a == ta && transpose_b == tb && transpose_c == tc) { \
367 EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
368 b, b_offset, ldb, temp_c, ldc, \
369 bit_depth); \
370 }
371
372 GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
373 GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
374 GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
375 GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
376 GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
377 GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
378 GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
379 GEMMLOWP_HANDLE_INT32_CASE(true, true, true)
380
381 #undef GEMMLOWP_HANDLE_INT32_CASE
382
383 if (transpose_c) {
384 // Row major.
385 for (int i = 0; i < m; ++i) {
386 float* dest_row = c + i * ldc;
387 std::int32_t* src_row = temp_c + i * ldc;
388 for (int j = 0; j < n; ++j) {
389 dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
390 }
391 }
392 } else {
393 // Column major.
394 for (int i = 0; i < n; ++i) {
395 float* dest_column = c + i * ldc;
396 std::int32_t* src_column = temp_c + i * ldc;
397 for (int j = 0; j < m; ++j) {
398 dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
399 }
400 }
401 }
402 }
403
SetMaxNumThreads(int n)404 void SetMaxNumThreads(int n) {
405 ScopedLock sl(GlobalMutexes::EightBitIntGemm());
406 GemmContext* context = GetOrCreateGlobalContext();
407 context->set_max_num_threads(n);
408 }
409
FreePersistentResources()410 void FreePersistentResources() {
411 ScopedLock sl(GlobalMutexes::EightBitIntGemm());
412 DestroyGlobalContext();
413 DestroyGlobalScratch();
414 }
415
416 } // namespace eight_bit_int_gemm
417 } // namespace gemmlowp
418