1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // multi_thread_gemv.h: Entry point to the multithreaded version of the
16 // generated (meta) gemv library.
17
18 #ifndef GEMMLOWP_META_MULTI_THREAD_GEMV_H_
19 #define GEMMLOWP_META_MULTI_THREAD_GEMV_H_
20
21 #ifdef GEMMLOWP_NEON
22
23 #include "legacy_multi_thread_common.h"
24 #include "legacy_operations_common.h"
25 #include "legacy_single_thread_gemm.h"
26
27 namespace gemmlowp {
28 namespace meta {
29 namespace internal {
30
31 class GemvQuantized8BitOperation : public Quantized8BitOperation {
32 public:
GemvQuantized8BitOperation(std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplier,std::int32_t shift)33 GemvQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
34 std::int32_t sum_offset, std::int32_t multiplier,
35 std::int32_t shift)
36 : Quantized8BitOperation(lhs_offset, rhs_offset, sum_offset, multiplier,
37 shift) {}
38
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::uint8_t * result,std::int32_t result_stride)39 void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
40 const std::uint8_t* rhs, std::int32_t m,
41 std::int32_t n, std::int32_t k, std::uint8_t* result,
42 std::int32_t result_stride) const {
43 gemv_q8(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, sum_offset,
44 multiplier, shift, result);
45 }
46
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)47 static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
48 std::int32_t k) {
49 return 128 * 1024;
50 }
51 };
52
53 class GemvFloatOperation : public FloatOperation {
54 public:
GemvFloatOperation(std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset)55 GemvFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
56 float result_offset)
57 : FloatOperation(lhs_offset, rhs_offset, result_offset) {}
58
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,float * result,std::int32_t result_stride)59 void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
60 const std::uint8_t* rhs, std::int32_t m,
61 std::int32_t n, std::int32_t k, float* result,
62 std::int32_t result_stride) const {
63 gemv_f(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result_offset,
64 result);
65 }
66
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)67 static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
68 std::int32_t k) {
69 return 128 * 1024;
70 }
71 };
72
73 class GemvInt32Operation : public Int32Operation {
74 public:
GemvInt32Operation(std::int32_t lhs_offset,std::int32_t rhs_offset)75 GemvInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
76 : Int32Operation(lhs_offset, rhs_offset) {}
77
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t * result,std::int32_t result_stride)78 void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
79 const std::uint8_t* rhs, std::int32_t m,
80 std::int32_t n, std::int32_t k, std::int32_t* result,
81 std::int32_t result_stride) const {
82 gemv_i32(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result);
83 }
84
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)85 static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
86 std::int32_t k) {
87 return 128 * 1024;
88 }
89 };
90
91 } // namespace internal
92
gemv_q8_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)93 std::int32_t gemv_q8_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
94 std::int32_t max_threads) {
95 return internal::ResolveMaxThreads(max_threads) *
96 internal::GemvQuantized8BitOperation::ScratchPerThread(m, n, k);
97 }
98
multi_thread_gemv_q8(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplier,std::int32_t shift,std::uint8_t * result)99 void multi_thread_gemv_q8(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
100 std::uint8_t* scratch, const std::uint8_t* lhs,
101 const std::uint8_t* rhs, std::int32_t n,
102 std::int32_t k, std::int32_t lhs_offset,
103 std::int32_t rhs_offset, std::int32_t sum_offset,
104 std::int32_t multiplier, std::int32_t shift,
105 std::uint8_t* result) {
106 max_threads = internal::ResolveMaxThreads(max_threads);
107 internal::GemvQuantized8BitOperation operation(lhs_offset, rhs_offset,
108 sum_offset, multiplier, shift);
109 if (max_threads == 1) {
110 operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
111 } else {
112 internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
113 n, k, result, n, operation);
114 }
115 }
116
gemv_f_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)117 std::int32_t gemv_f_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
118 std::int32_t max_threads) {
119 return internal::ResolveMaxThreads(max_threads) *
120 internal::GemvFloatOperation::ScratchPerThread(m, n, k);
121 }
122
multi_thread_gemv_f(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,float * result)123 void multi_thread_gemv_f(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
124 std::uint8_t* scratch, const std::uint8_t* lhs,
125 const std::uint8_t* rhs, std::int32_t n,
126 std::int32_t k, std::int32_t lhs_offset,
127 std::int32_t rhs_offset, float result_offset,
128 float* result) {
129 max_threads = internal::ResolveMaxThreads(max_threads);
130 internal::GemvFloatOperation operation(lhs_offset, rhs_offset, result_offset);
131 if (max_threads == 1) {
132 operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
133 } else {
134 internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
135 n, k, result, n, operation);
136 }
137 }
138
gemv_i32_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)139 std::int32_t gemv_i32_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
140 std::int32_t max_threads) {
141 return internal::ResolveMaxThreads(max_threads) *
142 internal::GemvInt32Operation::ScratchPerThread(m, n, k);
143 }
144
multi_thread_gemv_i32(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t * result)145 void multi_thread_gemv_i32(gemmlowp::WorkersPool* pool,
146 std::int32_t max_threads, std::uint8_t* scratch,
147 const std::uint8_t* lhs, const std::uint8_t* rhs,
148 std::int32_t n, std::int32_t k,
149 std::int32_t lhs_offset, std::int32_t rhs_offset,
150 std::int32_t* result) {
151 max_threads = internal::ResolveMaxThreads(max_threads);
152 internal::GemvInt32Operation operation(lhs_offset, rhs_offset);
153 if (max_threads == 1) {
154 operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
155 } else {
156 internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
157 n, k, result, n, operation);
158 }
159 }
160
161 } // namespace meta
162 } // namespace gemmlowp
163
164 #else
165 #warning "Meta gemm fast-path requires GEMMLOWP_NEON_(32|64)!"
166 #endif
167
168 #endif // GEMMLOWP_META_MULTI_THREAD_GEMV_H_
169