1 // Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // simd_wrappers.h: some inline functions wrapping SIMD intrinsics,
16 // extending the set of such functions from fixedpoint.h.
17
18 #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
19 #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
20
21 #include <algorithm>
22 #include <type_traits>
23 #include "../fixedpoint/fixedpoint.h"
24
25 namespace gemmlowp {
26
// Trait mapping a (scalar type, scalar count) pair to the register type used
// to store those scalars. This primary template is the scalar fallback: the
// "register" is simply one scalar, whatever ScalarCount is.
// NOTE(review): presumably the architecture-specific headers included at the
// end of this file specialize this trait for real SIMD types - confirm there.
template <typename ScalarType, int ScalarCount>
struct RegisterType {
  typedef ScalarType Type;
};
31
// Scalar overload of Min: returns the smaller of the two 32-bit integers.
inline std::int32_t Min(std::int32_t a, std::int32_t b) {
  return (b < a) ? b : a;
}
35
// Scalar overload of Max: returns the larger of the two 32-bit integers.
inline std::int32_t Max(std::int32_t a, std::int32_t b) {
  return (a < b) ? b : a;
}
39
// Scalar multiply-accumulate: *acc += lhs * rhs.
//
// The product and the sum are computed on unsigned 32-bit values so that a
// result outside the int32 range wraps around modulo 2^32 instead of
// triggering signed-overflow undefined behaviour. For in-range results this
// is identical to plain signed arithmetic.
//
//   lhs, rhs: the two factors.
//   acc:      accumulator updated in place; must not be null.
inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) {
  const std::uint32_t product =
      static_cast<std::uint32_t>(lhs) * static_cast<std::uint32_t>(rhs);
  const std::uint32_t sum = static_cast<std::uint32_t>(*acc) + product;
  *acc = static_cast<std::int32_t>(sum);
}
43
44 template <typename tScalarType, int tScalarCount>
45 struct RegisterBuffer {
46 using ScalarType = tScalarType;
47 static constexpr int kScalarCount = tScalarCount;
48 using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type;
49 static_assert((kScalarCount & (kScalarCount - 1)) == 0,
50 "kScalarCount must be a power of two");
51 static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, "");
52 static constexpr int kRegisterLanes =
53 sizeof(RegisterType) / sizeof(ScalarType);
54 static constexpr int kRegisterCount =
55 (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) /
56 sizeof(RegisterType);
57
58 RegisterType reg[kRegisterCount];
59 };
60
61 template <typename tScalarType, int tRows, int tCols>
62 struct RegisterBlock {
63 using ScalarType = tScalarType;
64 static constexpr int kRows = tRows;
65 static constexpr int kCols = tCols;
66 static constexpr int kScalarCount = kRows * kCols;
67 using BufferType = RegisterBuffer<ScalarType, kScalarCount>;
68 using RegisterType = typename BufferType::RegisterType;
69 static constexpr int kRegisterCount = BufferType::kRegisterCount;
70 static constexpr int kRegisterLanes = BufferType::kRegisterLanes;
71
72 BufferType buf;
73 };
74
// Adds two register blocks of the same type, register by register, using the
// Add() overload available for the block's register type.
template <typename RegisterBlockType>
struct RegisterBlockAddImpl {
  // Returns a block whose i-th register is Add(lhs register i, rhs register i).
  static RegisterBlockType Run(const RegisterBlockType& lhs,
                               const RegisterBlockType& rhs) {
    constexpr int kCount = RegisterBlockType::kRegisterCount;
    RegisterBlockType sum;
    for (int reg_index = 0; reg_index < kCount; ++reg_index) {
      sum.buf.reg[reg_index] =
          Add(lhs.buf.reg[reg_index], rhs.buf.reg[reg_index]);
    }
    return sum;
  }
};
86
87 template <typename RegisterBlockType>
RegisterBlockAdd(const RegisterBlockType & lhs,const RegisterBlockType & rhs)88 RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs,
89 const RegisterBlockType& rhs) {
90 return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs);
91 }
92
// Decides whether the operands of a broadcasting binary op should be swapped.
// kValue is true when LhsType holds fewer scalars than RhsType, or when both
// hold the same number of scalars and LhsType has fewer rows.
template <typename LhsType, typename RhsType>
struct ShouldFlipLhsRhs {
  static constexpr bool kValue =
      (LhsType::kScalarCount != RhsType::kScalarCount)
          ? (LhsType::kScalarCount < RhsType::kScalarCount)
          : (LhsType::kRows < RhsType::kRows);
};
100
101 template <typename LhsType, typename RhsType,
102 bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue>
103 struct FlipLhsRhs {
104 using FlippedLhsType = LhsType;
105 using FlippedRhsType = RhsType;
FlippedLhsFlipLhsRhs106 static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
107 const RhsType& rhs) {
108 return lhs;
109 }
FlippedRhsFlipLhsRhs110 static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
111 const RhsType& rhs) {
112 return rhs;
113 }
114 };
115
116 template <typename LhsType, typename RhsType>
117 struct FlipLhsRhs<LhsType, RhsType, true> {
118 using FlippedLhsType = RhsType;
119 using FlippedRhsType = LhsType;
120 static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
121 const RhsType& rhs) {
122 return rhs;
123 }
124 static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
125 const RhsType& rhs) {
126 return lhs;
127 }
128 };
129
// Result shape of a broadcasting binary op: in each dimension, the larger of
// the two operand extents.
template <typename Lhs, typename Rhs>
struct BroadcastBinaryOpShape {
  static constexpr int kRows =
      (Lhs::kRows < Rhs::kRows) ? Rhs::kRows : Lhs::kRows;
  static constexpr int kCols =
      (Lhs::kCols < Rhs::kCols) ? Rhs::kCols : Lhs::kCols;
};
137
138 template <typename Lhs, typename Rhs>
139 struct BroadcastBinaryOpRegisterBlock {
140 using Shape = BroadcastBinaryOpShape<Lhs, Rhs>;
141 using ScalarType = typename Lhs::ScalarType;
142 using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
143 };
144
145 template <typename Lhs, typename Rhs>
146 struct BroadcastAddImpl {
147 using ResultBlockType =
148 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
149 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
150 ResultBlockType result;
151 static constexpr int Rows = ResultBlockType::kRows;
152 static constexpr int Cols = ResultBlockType::kCols;
153 static constexpr int LhsRows = Lhs::kRows;
154 static constexpr int LhsCols = Lhs::kCols;
155 static constexpr int RhsRows = Rhs::kRows;
156 static constexpr int RhsCols = Rhs::kCols;
157
158 static_assert(LhsRows == Rows || LhsRows == 1, "");
159 static_assert(RhsRows == Rows || RhsRows == 1, "");
160 static_assert(LhsCols == Cols || LhsCols == 1, "");
161 static_assert(RhsCols == Cols || RhsCols == 1, "");
162 static_assert(ResultBlockType::kRegisterLanes == 1,
163 "This path is only for scalar values");
164 static_assert(Lhs::kRegisterLanes == 1,
165 "This path is only for scalar values");
166 static_assert(Rhs::kRegisterLanes == 1,
167 "This path is only for scalar values");
168
169 for (int c = 0; c < Cols; c++) {
170 const int lhs_c = LhsCols == Cols ? c : 0;
171 const int rhs_c = RhsCols == Cols ? c : 0;
172 for (int r = 0; r < Rows; r++) {
173 const int lhs_r = LhsRows == Rows ? r : 0;
174 const int rhs_r = RhsRows == Rows ? r : 0;
175 result.buf.reg[r + c * Rows] =
176 Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
177 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
178 }
179 }
180 return result;
181 }
182 };
183
184 template <typename Lhs, typename Rhs>
185 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd(
186 const Lhs& lhs, const Rhs& rhs) {
187 using Flip = FlipLhsRhs<Lhs, Rhs>;
188 return BroadcastAddImpl<
189 typename Flip::FlippedLhsType,
190 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
191 Flip::FlippedRhs(lhs, rhs));
192 }
193
194 template <typename Lhs, typename Rhs>
195 struct BroadcastMulImpl {
196 using ResultBlockType =
197 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
198 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
199 ResultBlockType result;
200 static constexpr int Rows = ResultBlockType::kRows;
201 static constexpr int Cols = ResultBlockType::kCols;
202 static constexpr int LhsRows = Lhs::kRows;
203 static constexpr int LhsCols = Lhs::kCols;
204 static constexpr int RhsRows = Rhs::kRows;
205 static constexpr int RhsCols = Rhs::kCols;
206 static_assert(ResultBlockType::kRegisterLanes == 1,
207 "This path is only for scalar values");
208 static_assert(Lhs::kRegisterLanes == 1,
209 "This path is only for scalar values");
210 static_assert(Rhs::kRegisterLanes == 1,
211 "This path is only for scalar values");
212
213 static_assert(LhsRows == Rows || LhsRows == 1, "");
214 static_assert(RhsRows == Rows || RhsRows == 1, "");
215 static_assert(LhsCols == Cols || LhsCols == 1, "");
216 static_assert(RhsCols == Cols || RhsCols == 1, "");
217 for (int c = 0; c < Cols; c++) {
218 const int lhs_c = LhsCols == Cols ? c : 0;
219 const int rhs_c = RhsCols == Cols ? c : 0;
220 for (int r = 0; r < Rows; r++) {
221 const int lhs_r = LhsRows == Rows ? r : 0;
222 const int rhs_r = RhsRows == Rows ? r : 0;
223 result.buf.reg[r + c * Rows] =
224 Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
225 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
226 }
227 }
228 return result;
229 }
230 };
231
232 template <typename Lhs, typename Rhs>
233 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul(
234 const Lhs& lhs, const Rhs& rhs) {
235 using Flip = FlipLhsRhs<Lhs, Rhs>;
236 return BroadcastMulImpl<
237 typename Flip::FlippedLhsType,
238 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
239 Flip::FlippedRhs(lhs, rhs));
240 }
241
// Generic (one scalar per register) implementation of broadcasting
// multiply-accumulate: every element of *acc receives MulAdd() of the
// matching (or broadcast) lhs and rhs elements. The accumulator's shape
// drives the iteration; each operand dimension must equal the accumulator's
// extent or be 1. Elements are addressed as reg[row + col * rows].
template <typename Lhs, typename Rhs, typename Acc>
struct BroadcastMulAddImpl {
  static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
    constexpr int kAccRows = Acc::kRows;
    constexpr int kAccCols = Acc::kCols;

    // One scalar per register everywhere.
    static_assert(Acc::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Lhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Rhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    // Shape compatibility.
    static_assert(Lhs::kRows == kAccRows || Lhs::kRows == 1, "");
    static_assert(Rhs::kRows == kAccRows || Rhs::kRows == 1, "");
    static_assert(Lhs::kCols == kAccCols || Lhs::kCols == 1, "");
    static_assert(Rhs::kCols == kAccCols || Rhs::kCols == 1, "");

    // Whether each operand spans the accumulator's extent in each dimension.
    constexpr bool kLhsAllRows = Lhs::kRows == kAccRows;
    constexpr bool kLhsAllCols = Lhs::kCols == kAccCols;
    constexpr bool kRhsAllRows = Rhs::kRows == kAccRows;
    constexpr bool kRhsAllCols = Rhs::kCols == kAccCols;

    for (int col = 0; col < kAccCols; ++col) {
      const int lhs_col_base = (kLhsAllCols ? col : 0) * Lhs::kRows;
      const int rhs_col_base = (kRhsAllCols ? col : 0) * Rhs::kRows;
      for (int row = 0; row < kAccRows; ++row) {
        const int lhs_index = lhs_col_base + (kLhsAllRows ? row : 0);
        const int rhs_index = rhs_col_base + (kRhsAllRows ? row : 0);
        MulAdd(lhs.buf.reg[lhs_index], rhs.buf.reg[rhs_index],
               &acc->buf.reg[col * kAccRows + row]);
      }
    }
  }
};
275
276 template <typename Lhs, typename Rhs, typename Acc>
277 void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
278 using Flip = FlipLhsRhs<Lhs, Rhs>;
279 BroadcastMulAddImpl<typename Flip::FlippedLhsType,
280 typename Flip::FlippedRhsType,
281 Acc>::Run(Flip::FlippedLhs(lhs, rhs),
282 Flip::FlippedRhs(lhs, rhs), acc);
283 }
284
// Catch-all LoadImpl. Loading is implemented by specializations; if this
// primary template gets instantiated with a real source type, the assertion
// below fails the build.
template <typename RegisterBlockType, typename SrcObjectType>
struct LoadImpl {
  static_assert(std::is_same<void, SrcObjectType>::value,
                "This generic impl should never be hit");
};
290
291 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType>
292 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
293 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
294 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
295 using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>;
296 static RegisterBlockType Run(const SrcObjectType& src, int row, int col) {
297 RegisterBlockType result;
298 int i = 0;
299 for (int c = 0; c < Cols; c++) {
300 const ScalarType* src_ptr = src.data(row, col + c);
301 for (int r = 0; r < Rows; r++) {
302 result.buf.reg[i++] = *src_ptr++;
303 }
304 }
305 return result;
306 }
307 };
308
309 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
310 VectorShape Shape>
311 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
312 VectorMap<SrcScalarType, Shape>> {
313 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
314 using SrcObjectType = VectorMap<SrcScalarType, Shape>;
315 static RegisterBlockType Run(const SrcObjectType& src, int pos) {
316 static_assert(Shape == VectorShape::Col || Rows == 1, "");
317 static_assert(Shape == VectorShape::Row || Cols == 1, "");
318 RegisterBlockType result;
319 for (int i = 0; i < Rows * Cols; i++) {
320 result.buf.reg[i] = src(pos + i);
321 }
322 return result;
323 }
324 };
325
326 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
327 VectorShape Shape>
328 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
329 VectorDup<SrcScalarType, Shape>> {
330 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
331 using SrcObjectType = VectorDup<SrcScalarType, Shape>;
332 static RegisterBlockType Run(const SrcObjectType& src, int) {
333 static_assert(Shape == VectorShape::Col || Rows == 1, "");
334 static_assert(Shape == VectorShape::Row || Cols == 1, "");
335 RegisterBlockType result;
336 for (int i = 0; i < Rows * Cols; i++) {
337 result.buf.reg[i] = src(0);
338 }
339 return result;
340 }
341 };
342
343 template <typename RegisterBlockType, typename SrcObjectType>
344 RegisterBlockType Load(const SrcObjectType& src, int row, int col) {
345 return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col);
346 }
347
348 template <typename RegisterBlockType, typename SrcObjectType>
349 RegisterBlockType Load(const SrcObjectType& src, int pos) {
350 return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos);
351 }
352
// Generic (one scalar per register) loader from a contiguous array.
template <typename RegisterBlockType>
struct LoadContiguousImpl {
  using ScalarType = typename RegisterBlockType::ScalarType;
  static_assert(RegisterBlockType::kRegisterLanes == 1,
                "This path is only for scalar values");

  // Copies the first kScalarCount scalars at src into the block's registers,
  // preserving their order.
  static RegisterBlockType Run(const ScalarType* src) {
    RegisterBlockType block;
    std::copy(src, src + RegisterBlockType::kScalarCount, block.buf.reg);
    return block;
  }
};
366
367 template <typename RegisterBlockType>
368 RegisterBlockType LoadContiguous(
369 const typename RegisterBlockType::ScalarType* src) {
370 return LoadContiguousImpl<RegisterBlockType>::Run(src);
371 }
372
// Shape trait for blocks loaded in order to be broadcast. The primary
// template is intentionally empty; kRows and kCols are supplied by the
// specializations for the supported source types.
template <int BroadcastRows, int BroadcastCols, typename SrcObjectType>
struct LoadForBroadcastingShape {
};
375
376 template <int BroadcastRows, int BroadcastCols, typename ScalarType,
377 VectorShape Shape>
378 struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
379 VectorMap<ScalarType, Shape>> {
380 static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1;
381 static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1;
382 };
383
384 template <int BroadcastRows, int BroadcastCols, typename ScalarType,
385 VectorShape Shape>
386 struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
387 VectorDup<ScalarType, Shape>> {
388 static constexpr int kRows = 1;
389 static constexpr int kCols = 1;
390 };
391
392 template <typename RegisterBlockType, typename SrcObjectType>
393 struct LoadForBroadcastingRegisterBlock {
394 using Shape =
395 LoadForBroadcastingShape<RegisterBlockType::kRows,
396 RegisterBlockType::kCols, SrcObjectType>;
397 using ScalarType = typename RegisterBlockType::ScalarType;
398 using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
399 };
400
// Catch-all LoadForBroadcastingImpl. The real work is done by
// specializations; if this primary template gets instantiated with a real
// source type, the assertion below fails the build.
template <typename RegisterBlockType, typename SrcObjectType>
struct LoadForBroadcastingImpl {
  static_assert(std::is_same<void, SrcObjectType>::value,
                "This generic impl should never be hit");
};
406
407 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
408 VectorShape Shape>
409 struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
410 VectorMap<SrcScalarType, Shape>> {
411 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
412 using SrcObjectType = VectorMap<SrcScalarType, Shape>;
413 using ResultBlockType =
414 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
415 SrcObjectType>::Type;
416 static_assert(ResultBlockType::kRegisterLanes == 1,
417 "This path is only for scalar values");
418 static ResultBlockType Run(const SrcObjectType& src, int pos) {
419 ResultBlockType result;
420 for (int c = 0; c < ResultBlockType::kCols; c++) {
421 for (int r = 0; r < ResultBlockType::kRows; r++) {
422 const int i = Shape == VectorShape::Col ? r : c;
423 result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i);
424 }
425 }
426 return result;
427 }
428 };
429
430 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
431 VectorShape Shape>
432 struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
433 VectorDup<SrcScalarType, Shape>> {
434 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
435 using SrcObjectType = VectorDup<SrcScalarType, Shape>;
436 using ResultBlockType =
437 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
438 SrcObjectType>::Type;
439 static_assert(ResultBlockType::kRegisterLanes == 1,
440 "This path is only for scalar values");
441 static ResultBlockType Run(const SrcObjectType& src, int) {
442 ResultBlockType result;
443 for (int c = 0; c < ResultBlockType::kCols; c++) {
444 for (int r = 0; r < ResultBlockType::kRows; r++) {
445 result.buf.reg[r + c * ResultBlockType::kRows] = src(0);
446 }
447 }
448 return result;
449 }
450 };
451
452 template <typename RegisterBlockType, typename SrcObjectType>
453 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
454 SrcObjectType>::Type
455 LoadForBroadcasting(const SrcObjectType& src, int row, int col) {
456 return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(
457 src, row, col);
458 }
459
460 template <typename RegisterBlockType, typename SrcObjectType>
461 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
462 SrcObjectType>::Type
463 LoadForBroadcasting(const SrcObjectType& src, int pos) {
464 return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src,
465 pos);
466 }
467
468 template <int ConstantValue, typename RegisterBlockType>
469 struct AddConstantImpl {
470 static void Run(RegisterBlockType* block) {
471 using RegisterType = typename RegisterBlockType::RegisterType;
472 const RegisterType dup = Dup<RegisterType>(ConstantValue);
473 for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) {
474 block->buf.reg[i] = Add(block->buf.reg[i], dup);
475 }
476 }
477 };
478
479 template <typename RegisterBlockType>
480 struct AddConstantImpl<0, RegisterBlockType> {
481 static void Run(RegisterBlockType*) {
482 // This is a no-op.
483 }
484 };
485
486 template <int ConstantValue, typename RegisterBlockType>
487 void AddConstant(RegisterBlockType* block) {
488 AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block);
489 }
490
491 template <int N>
492 using RegBufferInt32 = RegisterBuffer<std::int32_t, N>;
493 template <int N>
494 using RegBufferInt16 = RegisterBuffer<std::int16_t, N>;
495 template <int N>
496 using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
497 template <int R, int C>
498 using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
499 template <int R, int C>
500 using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>;
501 template <int R, int C>
502 using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
503
504 } // end namespace gemmlowp
505
506 #if defined GEMMLOWP_NEON
507 #include "simd_wrappers_neon.h"
508 #elif defined GEMMLOWP_SSE4
509 #include "simd_wrappers_sse.h"
510 #elif defined GEMMLOWP_MSA
511 #include "simd_wrappers_msa.h"
512 #endif
513
514 #endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
515