// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "eight_bit_int_gemm.h"

#include <cassert>
#include <cstdint>
#include <memory>

// gemmlowp symbols should have hidden visibility.
// Currently this is ensured in the build system by
// passing -fvisibility-inlines-hidden. TODO: it would be
// safer to hardcode it here with some #pragma's.
#include "../public/gemmlowp.h"

// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
// code. This code path consists of a number of meta-programmed, automatically
// generated GEMM kernels that are suitable for some sizes of input matrices.
// Because the generated code relies heavily on loop unrolling, inlining and
// currying of runtime parameters, the size of the generated binary is quite
// significant (approx. 200kb), which might be prohibitive in low-memory
// situations.

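// GEMMLOWP_USE_META_FASTPATH is an ordinary preprocessor macro; it is
// typically defined on the compiler command line,
// e.g. -DGEMMLOWP_USE_META_FASTPATH.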
#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
#include "../meta/legacy_multi_thread_gemm.h"
#else

#if defined(GEMMLOWP_USE_META_FASTPATH)
#warning "META fast path turned on without NEON!"
#endif

#endif

namespace gemmlowp {
namespace eight_bit_int_gemm {
namespace {

// To be used as template parameter for GlobalLock.
// GlobalLock<EightBitIntGemmLockId> is the global lock
// on EightBitIntGemm entry points, protecting
// EightBitIntGemm's global state.
struct EightBitIntGemmLockId;

// Global state: consists of one global GemmContext instance.
GemmContext* global_context;

GemmContext* GetOrCreateGlobalContext() {
  if (!global_context) {
    global_context = new GemmContext;
  }
  return global_context;
}

void DestroyGlobalContext() {
  delete global_context;
  global_context = nullptr;
}

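// Maps the (a, b, c) arguments onto gemmlowp (lhs, rhs, result) MatrixMaps,
// with storage orders fixed at compile time by the transpose flags, and
// dispatches to Gemm<> according to the requested bit depth.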
template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
                         const std::uint8_t* a, std::int32_t a_offset, int lda,
                         const std::uint8_t* b, std::int32_t b_offset, int ldb,
                         std::uint8_t* c, std::int32_t c_offset,
                         std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
                         BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;
  const int result_offset = c_offset;
  const int result_mult_int = c_mult_int;
  const int result_shift = c_shift;

  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS)     \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                 \
    Gemm<std::uint8_t, BIT_DEPTH_PARAMS>(                                  \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
        result_mult_int, result_shift);                                    \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH
  }
}

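// Same as EightBitIntGemmImpl, but stores the raw std::int32_t accumulators
// through an empty output pipeline instead of requantizing to uint8.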
template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
                              const std::uint8_t* a, std::int32_t a_offset,
                              int lda, const std::uint8_t* b,
                              std::int32_t b_offset, int ldb, std::int32_t* c,
                              int ldc, BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;

  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);

  auto empty_pipeline = std::make_tuple();

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                   \
    GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>(    \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
  }
}

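// Lazily grown scratch buffer. The allocation is over-sized by 32 bytes so
// that buffer() can return a pointer rounded up to a 32-byte boundary.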
class Scratch {
 public:
  Scratch() : buffer_(), buffer_32_(nullptr), size_(0) {}

  void AssureSize(std::int32_t required_size) {
    if (size_ >= required_size) {
      return;
    }
    buffer_.reset(new std::uint8_t[required_size + 32]);
    buffer_32_ =
        buffer_.get() +
        ((32 - (reinterpret_cast<uintptr_t>(buffer_.get()) % 32)) % 32);
    assert((reinterpret_cast<uintptr_t>(buffer_32_) % 32) == 0);
    size_ = required_size;
  }

  void Clear() {
    buffer_.reset(nullptr);
    buffer_32_ = nullptr;
    size_ = 0;
  }

  std::uint8_t* buffer() { return buffer_32_; }

 private:
  std::unique_ptr<std::uint8_t[]> buffer_;
  std::uint8_t* buffer_32_;
  std::int32_t size_;
};

Scratch* global_scratch = nullptr;

Scratch* GetOrCreateGlobalScratch() {
  if (global_scratch == nullptr) {
    global_scratch = new Scratch();
  }
  return global_scratch;
}

void DestroyGlobalScratch() {
  delete global_scratch;
  global_scratch = nullptr;
}

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)

bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
  // Is it row major and nicely packed?
  if (transpose && stride == cols) {
    return true;
  }

  // Is it a one row vector? (a vector is both row and column major)
  if (rows == 1) {
    return true;
  }

  return false;
}

bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
  // Is it column major and nicely packed?
  if (!transpose && stride == rows) {
    return true;
  }

  // Is it a one column vector? (a vector is both row and column major)
  if (cols == 1) {
    return true;
  }

  return false;
}

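// Returns true if the meta fastpath kernels can handle this call: 8-bit
// operands, 8 <= k <= 2048, a row-major (or vector) first operand, a
// column-major (or vector) second operand, and a contiguous result in
// either storage order.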
bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
                           int m, int n, int k, int lda, int ldb, int ldc,
                           BitDepthSetting depth_setting) {
  // Meta fastpath only supports 8bit x 8bit and k between 8 and 2048.
  if (depth_setting != BitDepthSetting::A8B8 || k < 8 || k > 2048) {
    return false;
  }

  // The first operand needs to be a row major matrix or a vector.
  if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
    return false;
  }

  // The second operand needs to be a column major matrix or a vector.
  if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
    return false;
  }

  // The result can either be a row major matrix, a column major matrix or
  // a vector.
  if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
    return true;
  }

  if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
    return true;
  }

  return false;
}

// Ensure enough scratch memory is allocated and run the fast path gemm.
void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, int m, int n, int k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t sum_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, bool result_transpose,
                           std::int32_t result_stride, std::uint8_t* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  const std::int32_t max_num_threads = context->max_num_threads();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(meta::gemm_q8_scratch(m, n, k, max_num_threads));
    meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
                               scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
                               rhs_offset, sum_offset, multiplicative_offset,
                               shift, result);
  } else {
    // Column-major result: compute the transposed product by swapping the
    // operands, since (A * B)^T = B^T * A^T.
    scratch->AssureSize(meta::gemm_q8_scratch(n, m, k, max_num_threads));
    meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
                               scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
                               lhs_offset, sum_offset, multiplicative_offset,
                               shift, result);
  }
}

// Ensure enough scratch memory is allocated and run the 8-bit-to-float fast
// path gemm.
void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, int m, int n, int k,
                   std::int32_t lhs_offset, std::int32_t rhs_offset,
                   float result_offset, bool result_transpose,
                   std::int32_t result_stride, float* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  const std::int32_t max_num_threads = context->max_num_threads();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(meta::gemm_f_scratch(m, n, k, max_num_threads));
    meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
                              scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, result);
  } else {
    scratch->AssureSize(meta::gemm_f_scratch(n, m, k, max_num_threads));
    meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
                              scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
                              lhs_offset, result_offset, result);
  }
}

#endif

}  // end anonymous namespace

// Public interface entry points

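// Example call for the uint8 entry point below (illustrative sketch only:
// the offsets, multiplier and shift are arbitrary placeholders, not a real
// quantization scheme). All matrices are column major here, so lda/ldb/ldc
// are the row counts of A, B and C respectively.
//
//   const int m = 4, n = 8, k = 16;
//   std::vector<std::uint8_t> a(m * k), b(k * n), c(m * n);
//   // ... fill a and b with quantized values ...
//   gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
//       /*transpose_a=*/false, /*transpose_b=*/false, /*transpose_c=*/false,
//       m, n, k,
//       a.data(), /*a_offset=*/-128, /*lda=*/m,
//       b.data(), /*b_offset=*/-128, /*ldb=*/k,
//       c.data(), /*c_offset=*/128, /*c_mult_int=*/1, /*c_shift=*/8,
//       /*ldc=*/m,
//       gemmlowp::eight_bit_int_gemm::BitDepthSetting::A8B8);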
void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, int lda, const std::uint8_t* b,
                     std::int32_t b_offset, int ldb, std::uint8_t* c,
                     std::int32_t c_offset, std::int32_t c_mult_int,
                     std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                          c_mult_int, c_shift, transpose_c, ldc, c);
    return;
  }
#endif

#define GEMMLOWP_HANDLE_CASE(ta, tb, tc)                                    \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {        \
    EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b,  \
                                    b_offset, ldb, c, c_offset, c_mult_int, \
                                    c_shift, ldc, bit_depth);               \
  }

  GEMMLOWP_HANDLE_CASE(false, false, false)
  GEMMLOWP_HANDLE_CASE(false, false, true)
  GEMMLOWP_HANDLE_CASE(false, true, false)
  GEMMLOWP_HANDLE_CASE(false, true, true)
  GEMMLOWP_HANDLE_CASE(true, false, false)
  GEMMLOWP_HANDLE_CASE(true, false, true)
  GEMMLOWP_HANDLE_CASE(true, true, false)
  GEMMLOWP_HANDLE_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_CASE
}

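// Float-output variant: computes the raw std::int32_t GEMM into scratch
// memory, then converts each accumulator to float, scaling it by c_offset
// (which acts as a multiplicative scale here, despite its name).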
void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, std::int32_t lda,
                     const std::uint8_t* b, std::int32_t b_offset,
                     std::int32_t ldb, float* c, float c_offset,
                     std::int32_t ldc, BitDepthSetting bit_depth) {
  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                  transpose_c, ldc, c);
    return;
  }
#endif

  // TODO(maciekc): implement a float output stage, get rid of scratch memory.
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (transpose_c) {
    scratch->AssureSize(m * ldc * sizeof(std::int32_t));
  } else {
    scratch->AssureSize(n * ldc * sizeof(std::int32_t));
  }
  std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());

#define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc)                               \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {         \
    EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
                                         b, b_offset, ldb, temp_c, ldc,      \
                                         bit_depth);                         \
  }

  GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_INT32_CASE

  if (transpose_c) {
    // Row major.
    for (int i = 0; i < m; ++i) {
      float* dest_row = c + i * ldc;
      std::int32_t* src_row = temp_c + i * ldc;
      for (int j = 0; j < n; ++j) {
        dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
      }
    }
  } else {
    // Column major.
    for (int i = 0; i < n; ++i) {
      float* dest_column = c + i * ldc;
      std::int32_t* src_column = temp_c + i * ldc;
      for (int j = 0; j < m; ++j) {
        dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
      }
    }
  }
}

void SetMaxNumThreads(int n) {
  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
  GemmContext* context = GetOrCreateGlobalContext();
  context->set_max_num_threads(n);
}

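// Frees the global GemmContext and scratch buffer; both are recreated lazily
// by the next call that needs them.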
void FreePersistentResources() {
  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
  DestroyGlobalContext();
  DestroyGlobalScratch();
}

}  // namespace eight_bit_int_gemm
}  // namespace gemmlowp