// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "eight_bit_int_gemm.h"

#include <memory>

// gemmlowp symbols should have hidden visibility.
// Currently this is ensured in the build system by
// passing -fvisibility=hidden. TODO: it would be
// safer to hardcode it here with some #pragmas.
#include "../public/gemmlowp.h"

// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
// code. This code path consists of a number of meta-programmed, automatically
// generated GEMM kernels that are suitable for some sizes of input matrices.
// Because the generated code relies heavily on loop unrolling, inlining and
// currying of runtime parameters, the size of the generated binary is quite
// significant (approx. 200kb), which might be prohibitive in low-memory
// situations.

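// The fastpath is typically enabled by defining the macro at build time, e.g.
// by passing -DGEMMLOWP_USE_META_FASTPATH to the compiler; it only takes
// effect on 32-bit ARM builds with NEON, where GEMMLOWP_NEON_32 is defined.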
#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
#include "../meta/multi_thread_gemm.h"
#endif

namespace gemmlowp {
namespace eight_bit_int_gemm {
namespace {

// To be used as template parameter for GlobalLock.
// GlobalLock<EightBitIntGemmLockId> is the global lock
// on EightBitIntGemm entry points, protecting
// EightBitIntGemm's global state.
struct EightBitIntGemmLockId;

// Global state: consists of one global GemmContext instance.
GemmContext* global_context;

GemmContext* GetOrCreateGlobalContext() {
  if (!global_context) {
    global_context = new GemmContext;
  }
  return global_context;
}

void DestroyGlobalContext() {
  delete global_context;
  global_context = nullptr;
}

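// Dispatches one EightBitIntGemm call, for a fixed combination of transpose
// flags, to gemmlowp's Gemm entry point: the legacy (a, b, c) arguments are
// mapped onto (lhs, rhs, result) MatrixMaps with the corresponding storage
// orders.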
template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
                         const std::uint8_t* a, std::int32_t a_offset, int lda,
                         const std::uint8_t* b, std::int32_t b_offset, int ldb,
                         std::uint8_t* c, std::int32_t c_offset,
                         std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
                         BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;
  const int result_offset = c_offset;
  const int result_mult_int = c_mult_int;
  const int result_shift = c_shift;

  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS)     \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                 \
    Gemm<std::uint8_t, BIT_DEPTH_PARAMS>(                                  \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
        result_mult_int, result_shift);                                    \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH
  }
}

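// Same dispatch as EightBitIntGemmImpl, but leaves the accumulators as raw
// std::int32_t values: GemmWithOutputPipeline is invoked with an empty output
// pipeline, so no offset/multiplier/shift requantization is applied.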
template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
                              const std::uint8_t* a, std::int32_t a_offset,
                              int lda, const std::uint8_t* b,
                              std::int32_t b_offset, int ldb, std::int32_t* c,
                              int ldc, BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;

  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);

  auto empty_pipeline = std::make_tuple();

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                   \
    GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>(    \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
  }
}

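// Lazily grown scratch buffer shared by the fast-path code below. AssureSize()
// only reallocates when the requested size exceeds the current capacity; the
// existing contents are not preserved across reallocation.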
class Scratch {
 public:
  Scratch() : buffer_(), size_(0) {}

  void AssureSize(std::int32_t required_size) {
    if (size_ >= required_size) {
      return;
    }
    buffer_.reset(new std::uint8_t[required_size]);
    size_ = required_size;
  }

  void Clear() {
    buffer_.reset(nullptr);
    size_ = 0;
  }

  std::uint8_t* buffer() { return buffer_.get(); }

 private:
  std::unique_ptr<std::uint8_t[]> buffer_;
  std::int32_t size_;
};

Scratch* global_scratch = nullptr;

Scratch* GetOrCreateGlobalScratch() {
  if (global_scratch == nullptr) {
    global_scratch = new Scratch();
  }
  return global_scratch;
}

void DestroyGlobalScratch() {
  delete global_scratch;
  global_scratch = nullptr;
}

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)

bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
  // Is it row major and nicely packed?
  if (transpose && stride == cols) {
    return true;
  }

  // Is it a one row vector? (a vector is both row and column major)
  if (rows == 1) {
    return true;
  }

  return false;
}

bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
  // Is it column major and nicely packed?
  if (!transpose && stride == rows) {
    return true;
  }

  // Is it a one column vector? (a vector is both row and column major)
  if (cols == 1) {
    return true;
  }

  return false;
}

bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
                           int m, int n, int k, int lda, int ldb, int ldc,
                           BitDepthSetting depth_setting) {
  // Meta fastpath only supports 8bit x 8bit and k up to 2048.
  if (depth_setting != BitDepthSetting::A8B8 || k > 2048) {
    return false;
  }

  // The first operand needs to be a row major matrix or a vector.
  if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
    return false;
  }

  // The second operand needs to be a column major matrix or a vector.
  if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
    return false;
  }

  // The result can either be a row major matrix, a column major matrix or
  // a vector.
  if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
    return true;
  }

  if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
    return true;
  }

  return false;
}

// Assure enough scratch memory is allocated and run the fast path gemm.
void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, int m, int n, int k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t sum_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, bool result_transpose,
                           std::int32_t result_stride, std::uint8_t* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(
        meta::gemm_q8_scratch(m, n, k, context->max_num_threads()));
    meta::multi_thread_gemm_q8(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        lhs, rhs, m, n, k, lhs_offset, rhs_offset, sum_offset,
        multiplicative_offset, shift, result);
  } else {
    scratch->AssureSize(
        meta::gemm_q8_scratch(n, m, k, context->max_num_threads()));
    meta::multi_thread_gemm_q8(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        rhs, lhs, n, m, k, rhs_offset, lhs_offset, sum_offset,
        multiplicative_offset, shift, result);
  }
}

// Assure enough scratch memory is allocated and run the 8bit to float fast
// path gemm.
void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, int m, int n, int k,
                   std::int32_t lhs_offset, std::int32_t rhs_offset,
                   float result_offset, bool result_transpose,
                   std::int32_t result_stride, float* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(
        meta::gemm_f_scratch(m, n, k, context->max_num_threads()));
    meta::multi_thread_gemm_f(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset, result);
  } else {
    scratch->AssureSize(
        meta::gemm_f_scratch(n, m, k, context->max_num_threads()));
    meta::multi_thread_gemm_f(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        rhs, lhs, n, m, k, rhs_offset, lhs_offset, result_offset, result);
  }
}

#endif

}  // end anonymous namespace

// Public interface entry points

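// Broadly, for the uint8 variant below, each output entry comes from the int32
// accumulation of (a + a_offset) * (b + b_offset) products over k, which is
// then offset by c_offset, multiplied by c_mult_int, right-shifted by c_shift
// and clamped to [0, 255]; the exact rounding is defined by the underlying
// Gemm output stage. An illustrative call for a plain column-major 2x4 * 4x3
// product with no offsets and no rescaling might look like:
//
//   std::uint8_t a[2 * 4], b[4 * 3], c[2 * 3];
//   // ... fill a and b ...
//   EightBitIntGemm(false, false, false, /*m=*/2, /*n=*/3, /*k=*/4,
//                   a, /*a_offset=*/0, /*lda=*/2,
//                   b, /*b_offset=*/0, /*ldb=*/4,
//                   c, /*c_offset=*/0, /*c_mult_int=*/1, /*c_shift=*/0,
//                   /*ldc=*/2, BitDepthSetting::A8B8);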
void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, int lda, const std::uint8_t* b,
                     std::int32_t b_offset, int ldb, std::uint8_t* c,
                     std::int32_t c_offset, std::int32_t c_mult_int,
                     std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                          c_mult_int, c_shift, transpose_c, ldc, c);
    return;
  }
#endif

#define GEMMLOWP_HANDLE_CASE(ta, tb, tc)                                    \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {        \
    EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b,  \
                                    b_offset, ldb, c, c_offset, c_mult_int, \
                                    c_shift, ldc, bit_depth);               \
  }

  GEMMLOWP_HANDLE_CASE(false, false, false)
  GEMMLOWP_HANDLE_CASE(false, false, true)
  GEMMLOWP_HANDLE_CASE(false, true, false)
  GEMMLOWP_HANDLE_CASE(false, true, true)
  GEMMLOWP_HANDLE_CASE(true, false, false)
  GEMMLOWP_HANDLE_CASE(true, false, true)
  GEMMLOWP_HANDLE_CASE(true, true, false)
  GEMMLOWP_HANDLE_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_CASE
}

void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, std::int32_t lda,
                     const std::uint8_t* b, std::int32_t b_offset,
                     std::int32_t ldb, float* c, float c_offset,
                     std::int32_t ldc, BitDepthSetting bit_depth) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                  transpose_c, ldc, c);
    return;
  }
#endif

  // TODO(maciekc): implement a float output stage, get rid of scratch memory.
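  // In the meantime, the int32 GEMM result is materialized into the global
  // scratch buffer and then converted entry-by-entry to float, with c_offset
  // acting as a multiplicative scale factor (despite its name).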
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (transpose_c) {
    scratch->AssureSize(m * ldc * sizeof(std::int32_t));
  } else {
    scratch->AssureSize(n * ldc * sizeof(std::int32_t));
  }
  std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());

#define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc)                               \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {         \
    EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
                                         b, b_offset, ldb, temp_c, ldc,      \
                                         bit_depth);                         \
  }

  GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_INT32_CASE

  if (transpose_c) {
    // Row major.
    for (int i = 0; i < m; ++i) {
      float* dest_row = c + i * ldc;
      std::int32_t* src_row = temp_c + i * ldc;
      for (int j = 0; j < n; ++j) {
        dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
      }
    }
  } else {
    // Column major.
    for (int i = 0; i < n; ++i) {
      float* dest_column = c + i * ldc;
      std::int32_t* src_column = temp_c + i * ldc;
      for (int j = 0; j < m; ++j) {
        dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
      }
    }
  }
}

void SetMaxNumThreads(int n) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();
  context->set_max_num_threads(n);
}

void FreePersistentResources() {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  DestroyGlobalContext();
  DestroyGlobalScratch();
}

}  // namespace eight_bit_int_gemm
}  // namespace gemmlowp