• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // multi_thread_gemv.h: Entry point to the multithreaded version of the
16 // generated (meta) gemv library.
17 
18 #ifndef GEMMLOWP_META_MULTI_THREAD_GEMV_H_
19 #define GEMMLOWP_META_MULTI_THREAD_GEMV_H_
20 
21 #ifdef GEMMLOWP_NEON
22 
23 #include "legacy_multi_thread_common.h"
24 #include "legacy_operations_common.h"
25 #include "legacy_single_thread_gemm.h"
26 
27 namespace gemmlowp {
28 namespace meta {
29 namespace internal {
30 
31 class GemvQuantized8BitOperation : public Quantized8BitOperation {
32  public:
GemvQuantized8BitOperation(std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplier,std::int32_t shift)33   GemvQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
34                              std::int32_t sum_offset, std::int32_t multiplier,
35                              std::int32_t shift)
36       : Quantized8BitOperation(lhs_offset, rhs_offset, sum_offset, multiplier,
37                                shift) {}
38 
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::uint8_t * result,std::int32_t result_stride)39   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
40                            const std::uint8_t* rhs, std::int32_t m,
41                            std::int32_t n, std::int32_t k, std::uint8_t* result,
42                            std::int32_t result_stride) const {
43     gemv_q8(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, sum_offset,
44             multiplier, shift, result);
45   }
46 
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)47   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
48                                        std::int32_t k) {
49     return 128 * 1024;
50   }
51 };
52 
53 class GemvFloatOperation : public FloatOperation {
54  public:
GemvFloatOperation(std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset)55   GemvFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
56                      float result_offset)
57       : FloatOperation(lhs_offset, rhs_offset, result_offset) {}
58 
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,float * result,std::int32_t result_stride)59   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
60                            const std::uint8_t* rhs, std::int32_t m,
61                            std::int32_t n, std::int32_t k, float* result,
62                            std::int32_t result_stride) const {
63     gemv_f(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result_offset,
64            result);
65   }
66 
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)67   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
68                                        std::int32_t k) {
69     return 128 * 1024;
70   }
71 };
72 
73 class GemvInt32Operation : public Int32Operation {
74  public:
GemvInt32Operation(std::int32_t lhs_offset,std::int32_t rhs_offset)75   GemvInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
76       : Int32Operation(lhs_offset, rhs_offset) {}
77 
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t * result,std::int32_t result_stride)78   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
79                            const std::uint8_t* rhs, std::int32_t m,
80                            std::int32_t n, std::int32_t k, std::int32_t* result,
81                            std::int32_t result_stride) const {
82     gemv_i32(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result);
83   }
84 
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)85   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
86                                        std::int32_t k) {
87     return 128 * 1024;
88   }
89 };
90 
91 }  // namespace internal
92 
gemv_q8_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)93 std::int32_t gemv_q8_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
94                              std::int32_t max_threads) {
95   return internal::ResolveMaxThreads(max_threads) *
96          internal::GemvQuantized8BitOperation::ScratchPerThread(m, n, k);
97 }
98 
multi_thread_gemv_q8(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplier,std::int32_t shift,std::uint8_t * result)99 void multi_thread_gemv_q8(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
100                           std::uint8_t* scratch, const std::uint8_t* lhs,
101                           const std::uint8_t* rhs, std::int32_t n,
102                           std::int32_t k, std::int32_t lhs_offset,
103                           std::int32_t rhs_offset, std::int32_t sum_offset,
104                           std::int32_t multiplier, std::int32_t shift,
105                           std::uint8_t* result) {
106   max_threads = internal::ResolveMaxThreads(max_threads);
107   internal::GemvQuantized8BitOperation operation(lhs_offset, rhs_offset,
108                                                  sum_offset, multiplier, shift);
109   if (max_threads == 1) {
110     operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
111   } else {
112     internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
113                                         n, k, result, n, operation);
114   }
115 }
116 
gemv_f_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)117 std::int32_t gemv_f_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
118                             std::int32_t max_threads) {
119   return internal::ResolveMaxThreads(max_threads) *
120          internal::GemvFloatOperation::ScratchPerThread(m, n, k);
121 }
122 
multi_thread_gemv_f(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,float * result)123 void multi_thread_gemv_f(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
124                          std::uint8_t* scratch, const std::uint8_t* lhs,
125                          const std::uint8_t* rhs, std::int32_t n,
126                          std::int32_t k, std::int32_t lhs_offset,
127                          std::int32_t rhs_offset, float result_offset,
128                          float* result) {
129   max_threads = internal::ResolveMaxThreads(max_threads);
130   internal::GemvFloatOperation operation(lhs_offset, rhs_offset, result_offset);
131   if (max_threads == 1) {
132     operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
133   } else {
134     internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
135                                         n, k, result, n, operation);
136   }
137 }
138 
gemv_i32_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)139 std::int32_t gemv_i32_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
140                               std::int32_t max_threads) {
141   return internal::ResolveMaxThreads(max_threads) *
142          internal::GemvInt32Operation::ScratchPerThread(m, n, k);
143 }
144 
multi_thread_gemv_i32(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t * result)145 void multi_thread_gemv_i32(gemmlowp::WorkersPool* pool,
146                            std::int32_t max_threads, std::uint8_t* scratch,
147                            const std::uint8_t* lhs, const std::uint8_t* rhs,
148                            std::int32_t n, std::int32_t k,
149                            std::int32_t lhs_offset, std::int32_t rhs_offset,
150                            std::int32_t* result) {
151   max_threads = internal::ResolveMaxThreads(max_threads);
152   internal::GemvInt32Operation operation(lhs_offset, rhs_offset);
153   if (max_threads == 1) {
154     operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
155   } else {
156     internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
157                                         n, k, result, n, operation);
158   }
159 }
160 
161 }  // namespace meta
162 }  // namespace gemmlowp
163 
164 #else
165 #warning "Meta gemm fast-path requires GEMMLOWP_NEON_(32|64)!"
166 #endif
167 
168 #endif  // GEMMLOWP_META_MULTI_THREAD_GEMV_H_
169