/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_KERNEL_X86_H_
#define RUY_RUY_KERNEL_X86_H_

#include <cstdint>
#include <cstring>

#include "ruy/kernel_common.h"
#include "ruy/mat.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/tune.h"

namespace ruy {

#if RUY_PLATFORM_X86

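// Make the x86 SIMD paths inherit kernels from their parent paths for any
// operand/destination type combination that is not explicitly specialized
// below.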
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx2Fma)
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx)
RUY_INHERIT_KERNEL(Path::kAvx2Fma, Path::kAvx512)

void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params);

template <typename DstScalar>
struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, std::int32_t,
              DstScalar> {
  static constexpr Path kPath = Path::kAvx512;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvx512SingleCol(params);
    } else {
      Kernel8bitAvx512(params);
    }
  }
};

void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params);

template <>
struct Kernel<Path::kAvx512, float, float, float, float> {
  static constexpr Path kPath = Path::kAvx512;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      KernelFloatAvx512SingleCol(params);
    } else {
      KernelFloatAvx512(params);
    }
  }
};

void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params);
void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params);

template <typename DstScalar>
struct Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t,
              DstScalar> {
  static constexpr Path kPath = Path::kAvx2Fma;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvx2SingleCol(params);
    } else {
      Kernel8bitAvx2(params);
    }
  }
};

void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);

template <>
struct Kernel<Path::kAvx2Fma, float, float, float, float> {
  static constexpr Path kPath = Path::kAvx2Fma;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      KernelFloatAvx2SingleCol(params);
    } else {
      KernelFloatAvx2(params);
    }
  }
};

void KernelFloatAvx(const KernelParamsFloat<8, 8>& params);
void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params);

template <>
struct Kernel<Path::kAvx, float, float, float, float> {
  static constexpr Path kPath = Path::kAvx;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      KernelFloatAvxSingleCol(params);
    } else {
      KernelFloatAvx(params);
    }
  }
};

void Kernel8bitAvx(const KernelParams8bit<8, 8>& params);
void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params);

template <typename DstScalar>
struct Kernel<Path::kAvx, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
  static constexpr Path kPath = Path::kAvx;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvxSingleCol(params);
    } else {
      Kernel8bitAvx(params);
    }
  }
};

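// Illustrative usage sketch (assumes the public API declared in ruy/ruy.h;
// the dimensions M, K, N and the *_data pointers are placeholders). Callers
// do not invoke these Kernel specializations directly: a ruy::Mul call
// selects the best supported Path at runtime and dispatches to the matching
// Kernel<Path, ...>::Run on packed blocks.
//
//   ruy::Context context;
//   ruy::Matrix<float> lhs, rhs, dst;
//   ruy::MakeSimpleLayout(M, K, ruy::Order::kRowMajor, lhs.mutable_layout());
//   ruy::MakeSimpleLayout(K, N, ruy::Order::kColMajor, rhs.mutable_layout());
//   ruy::MakeSimpleLayout(M, N, ruy::Order::kColMajor, dst.mutable_layout());
//   lhs.set_data(lhs_data);
//   rhs.set_data(rhs_data);
//   dst.set_data(dst_data);
//   ruy::MulParams<float, float> mul_params;
//   ruy::Mul(lhs, rhs, mul_params, &context, &dst);
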
#endif  // RUY_PLATFORM_X86
}  // namespace ruy

#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))

#include <immintrin.h>  // IWYU pragma: keep

namespace ruy {
namespace {
namespace intrin_utils {

// Defined as a template so clang won't detect it as an unneeded definition.
template <Path path>
inline float mm256_get1_ps(const __m256 a, int i) {
  __m256i ai = _mm256_castps_si256(a);
  int float_val_as_int;
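  // _mm256_extract_epi32 requires an immediate (compile-time constant) index,
  // hence the switch over the runtime index i below.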
  switch (i) {
    case 0:
      float_val_as_int = _mm256_extract_epi32(ai, 0);
      break;
    case 1:
      float_val_as_int = _mm256_extract_epi32(ai, 1);
      break;
    case 2:
      float_val_as_int = _mm256_extract_epi32(ai, 2);
      break;
    case 3:
      float_val_as_int = _mm256_extract_epi32(ai, 3);
      break;
    case 4:
      float_val_as_int = _mm256_extract_epi32(ai, 4);
      break;
    case 5:
      float_val_as_int = _mm256_extract_epi32(ai, 5);
      break;
    case 6:
      float_val_as_int = _mm256_extract_epi32(ai, 6);
      break;
    case 7:
      float_val_as_int = _mm256_extract_epi32(ai, 7);
      break;
    default:
      RUY_DCHECK_LT(i, 8);
      return .0f;
  }
  float float_val;
  std::memcpy(&float_val, &float_val_as_int, sizeof(float_val));
  return float_val;
}

// Defined as a template so clang won't detect it as an unneeded definition.
template <Path path>
inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
  for (int i = 0; i < residual_rows; ++i) {
    dst[i] = intrin_utils::mm256_get1_ps<path>(v, i);
  }
}

template <Path path>
inline __m256 MulAdd(const __m256&, const __m256&, const __m256&) {
  // Specializations added for AVX and AVX2FMA paths in their respective kernel
  // files.
  RUY_DCHECK(false);
  return _mm256_set1_ps(0);
}

template <Path path>
inline __m256i mm256_shuffle_epi8(const __m256i&, const __m256i&) {
  // Specializations added for AVX and AVX2FMA paths in their respective kernel
  // files.
  RUY_DCHECK(false);
  return _mm256_set1_epi32(0);
}

// Polyfill for _mm_storeu_si16(dst, v).
template <Path path>
inline void mm_storeu_si16(void* dst, __m128i v) {
#if (defined __clang__) || (defined _MSC_VER)
  _mm_storeu_si16(dst, v);
#else
  // GCC 9 lacks support for _mm_storeu_si16.
  *static_cast<std::int16_t*>(dst) = _mm_extract_epi16(v, 0);
#endif
}

// Polyfill for _mm_storeu_si32(dst, v).
template <Path path>
inline void mm_storeu_si32(void* dst, __m128i v) {
#if (defined __clang__) || (defined _MSC_VER)
  _mm_storeu_si32(dst, v);
#else
  // GCC 9 lacks support for _mm_storeu_si32.
  *static_cast<std::int32_t*>(dst) = _mm_extract_epi32(v, 0);
#endif
}

// Polyfill for _mm_loadu_si32(src).
template <Path path>
inline __m128i mm_loadu_si32(const void* src) {
#if (defined __clang__) || (defined _MSC_VER)
  return _mm_loadu_si32(src);
#else
  // GCC 9 lacks support for _mm_loadu_si32.
  __m128i res;
  asm("movss %[src], %[res]"
      : [res] "=x"(res)
      : [src] "m"(*static_cast<const int*>(src)));
  return res;
#endif
}

template <Path path>
inline __m128i mm256_extracti128_si256(const __m256i&, const int) {
  RUY_DCHECK(false);
  return _mm_setzero_si128();
}

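// Stores the low byte of each of the first `residual_rows` 32-bit elements of
// v to dst (int32 -> 8-bit truncation), for partially filled destination
// blocks.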
template <Path path>
inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows,
                                         const __m256i v) {
  // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
  const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
  __m256i shuffled_v;
  if (residual_rows > 1) {
    // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4
    // in each 128-bit lane.
    shuffled_v = intrin_utils::mm256_shuffle_epi8<path>(v, repack_perm);
  }
  switch (residual_rows) {
    case 0:
      break;
    case 1:
      dst[0] = _mm256_extract_epi8(v, 0);
      break;
    case 2:
      mm_storeu_si16<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      break;
    case 3: {
      __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 0);
      mm_storeu_si16<path>(dst, trailing_packed);
      dst[2] = _mm_extract_epi8(trailing_packed, 2);
      break;
    }
    case 4:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      break;
    case 5:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      dst[4] = _mm256_extract_epi8(shuffled_v, 16);
      break;
    case 6:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      mm_storeu_si16<path>(dst + 4,
                           mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    case 7: {
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1);
      mm_storeu_si16<path>(dst + 4, trailing_packed);
      dst[6] = _mm_extract_epi8(trailing_packed, 2);
      break;
    }
    case 8:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      mm_storeu_si32<path>(dst + 4,
                           mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    default:
      RUY_DCHECK_LE(residual_rows, 8);
      break;
  }
}

template <Path path>
inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256i v) {
  // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
  const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
  const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
  mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
  mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
}

template <Path path>
inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows,
                                         const __m256i v) {
  intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
      reinterpret_cast<std::uint8_t*>(dst), residual_rows, v);
}

template <Path path>
inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256i v) {
  // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
  const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
  const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
  mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
  mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
}

template <Path path>
inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows,
                                          const __m256i v) {
  // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
  // truncating each 32-bit integer to 16 bits.
  const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
  __m256i shuffled_v;
  __m128i shuffled_v_low;
  if (residual_rows > 1) {
    shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
    shuffled_v_low = mm256_extracti128_si256<path>(shuffled_v, 0);
  } else {
    shuffled_v_low = mm256_extracti128_si256<path>(v, 0);
  }
  switch (residual_rows) {
    case 0:
      break;
    case 1:
      mm_storeu_si16<path>(dst, shuffled_v_low);
      break;
    case 2:
      mm_storeu_si32<path>(dst, shuffled_v_low);
      break;
    case 3: {
      mm_storeu_si32<path>(dst, shuffled_v_low);
      dst[2] = _mm_extract_epi16(shuffled_v_low, 2);
      break;
    }
    case 4:
      _mm_storeu_si64(dst, shuffled_v_low);
      break;
    case 5:
      _mm_storeu_si64(dst, shuffled_v_low);
      dst[4] = _mm256_extract_epi16(shuffled_v, 8);
      break;
    case 6:
      _mm_storeu_si64(dst, shuffled_v_low);
      mm_storeu_si32<path>(dst + 4,
                           mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    case 7: {
      _mm_storeu_si64(dst, shuffled_v_low);
      __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1);
      mm_storeu_si32<path>(dst + 4, trailing_packed);
      dst[6] = _mm_extract_epi16(trailing_packed, 2);
      break;
    }
    case 8:
      _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    default:
      RUY_DCHECK_LE(residual_rows, 8);
      break;
  }
}

template <Path path>
inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256i v) {
  // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
  // truncating each 32-bit integer to 16 bits.
  const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
  const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
  _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
  _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
}

template <Path path>
inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows,
                                 const __m256i v) {
  const __m128i v_low = mm256_extracti128_si256<path>(v, 0);
  switch (residual_rows) {
    case 0:
      break;
    case 1:
      mm_storeu_si32<path>(dst, v_low);
      break;
    case 2:
      _mm_storeu_si64(dst, v_low);
      break;
    case 3: {
      __m128i trailing_packed = v_low;
      _mm_storeu_si64(dst, trailing_packed);
      dst[2] = _mm_extract_epi32(trailing_packed, 2);
      break;
    }
    case 4:
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      break;
    case 5:
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      dst[4] = _mm256_extract_epi32(v, 4);
      break;
    case 6:
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(v, 1));
      break;
    case 7: {
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      __m128i trailing_packed = mm256_extracti128_si256<path>(v, 1);
      _mm_storeu_si64(dst + 4, trailing_packed);
      dst[6] = _mm_extract_epi32(trailing_packed, 2);
      break;
    }
    case 8:
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
      break;
    default:
      RUY_DCHECK_LE(residual_rows, 8);
      break;
  }
}

template <Path path>
inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) {
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
}

// Transpose an 8x8 matrix of floats.
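// The transpose proceeds in three stages: unpacklo/unpackhi interleave
// adjacent rows, shuffle_ps then forms transposed 4x4 blocks within each
// 128-bit lane, and permute2f128 finally recombines the 128-bit lanes across
// register pairs into the output rows.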
template <Path path>
void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3,
                           __m256* v4, __m256* v5, __m256* v6, __m256* v7) {
  __m256 t2x2_0 = _mm256_unpacklo_ps(*v0, *v1);
  __m256 t2x2_1 = _mm256_unpackhi_ps(*v0, *v1);
  __m256 t2x2_2 = _mm256_unpacklo_ps(*v2, *v3);
  __m256 t2x2_3 = _mm256_unpackhi_ps(*v2, *v3);
  __m256 t2x2_4 = _mm256_unpacklo_ps(*v4, *v5);
  __m256 t2x2_5 = _mm256_unpackhi_ps(*v4, *v5);
  __m256 t2x2_6 = _mm256_unpacklo_ps(*v6, *v7);
  __m256 t2x2_7 = _mm256_unpackhi_ps(*v6, *v7);
  __m256 t4x4_0 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_1 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 t4x4_2 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_3 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 t4x4_4 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_5 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 t4x4_6 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_7 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(3, 2, 3, 2));
  *v0 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x20);
  *v1 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x20);
  *v2 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x20);
  *v3 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x20);
  *v4 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x31);
  *v5 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x31);
  *v6 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x31);
  *v7 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x31);
}

// Transpose an 8x8 matrix of int32s.
template <Path path>
void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2,
                              __m256i* v3, __m256i* v4, __m256i* v5,
                              __m256i* v6, __m256i* v7) {
  mm256_transpose8x8_ps<path>(
      reinterpret_cast<__m256*>(v0), reinterpret_cast<__m256*>(v1),
      reinterpret_cast<__m256*>(v2), reinterpret_cast<__m256*>(v3),
      reinterpret_cast<__m256*>(v4), reinterpret_cast<__m256*>(v5),
      reinterpret_cast<__m256*>(v6), reinterpret_cast<__m256*>(v7));
}

}  // namespace intrin_utils
}  // namespace

template <Path path>
inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
  // Strides in params are in bytes; divide by sizeof(float) to get elements.
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  const std::int64_t dst_stride = params.dst_stride >> 2;
  const std::int64_t rhs_stride = params.rhs_stride >> 2;
  //
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  // AVX2 float block size = 8.
  const int end_row = std::min(params.dst_rows, params.last_row + 8);
  const int end_col = std::min(params.dst_cols, params.last_col + 8);
  //
  const float* adj_rhs_col_ptr =
      params.rhs_base_ptr - params.start_col * rhs_stride;
  float* adj_dst_col_ptr =
      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_ptr = params.bias;

  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
  const bool channel_dimension_is_col =
      params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;

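  // The destination is processed in 8x8 blocks: accum_data_v[j] accumulates
  // the 8 rows of destination column (col + j).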
  int col = params.start_col;
  // Loop through cols by the float block size; the incomplete remainder is
  // handled after this loop.
  for (; col <= end_col - 8; col += 8) {
    __m256 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    for (int row = params.start_row; row < end_row; row += 8) {
      const int residual_rows = std::min(end_row - row, 8);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Initialize with bias.
      if (channel_dimension_is_col) {
        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
        }
      } else {
        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);

        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = initial_accum_data;
        }
      }

      const float* lhs_ptr = lhs_col_ptr;
      const float* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; ++d) {
        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
        // Load 8 RHS values, then use permute instructions to broadcast each
        // value to a register. _mm256_permute2f128_ps is slow on AMD.
        __m256 rhs0_3 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
        __m256 rhs4_7 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));

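        // _mm256_permute_ps with immediate 0, 85 (0b01010101), 170
        // (0b10101010) or 255 selects element 0, 1, 2 or 3 of each 128-bit
        // lane; since rhs0_3 and rhs4_7 hold the same 4 RHS values in both
        // lanes, each permute broadcasts one RHS value to all 8 float slots.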
        const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
        accum_data_v[0] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_0, accum_data_v[0]);

        const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85);
        accum_data_v[1] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_1, accum_data_v[1]);

        const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170);
        accum_data_v[2] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_2, accum_data_v[2]);

        const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255);
        accum_data_v[3] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_3, accum_data_v[3]);

        const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0);
        accum_data_v[4] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_4, accum_data_v[4]);

        const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85);
        accum_data_v[5] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_5, accum_data_v[5]);

        const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170);
        accum_data_v[6] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_6, accum_data_v[6]);

        const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255);
        accum_data_v[7] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_7, accum_data_v[7]);

        lhs_ptr += 8;
        rhs_ptr += 8;
      }

      if (residual_rows == 8) {
        for (int j = 0; j < 8; ++j) {
          float* block_ptr = dst_ptr + j * dst_stride;
          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
          _mm256_storeu_ps(block_ptr, accum_data_v[j]);
        }
      } else {
        for (int j = 0; j < 8; ++j) {
          float* block_ptr = dst_ptr + j * dst_stride;
          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
          intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
                                                accum_data_v[j]);
        }
      }
    }  // End row-block loop.
  }    // End col-block loop.

  if (col < end_col) {
    // Remaining cols in [0, float block size).
    RUY_DCHECK_GE(end_col - col, 0);
    RUY_DCHECK_LT(end_col - col, 8);

    __m256 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
    const int residual_cols = std::min(end_col - col, 8);

    for (int row = params.start_row; row < end_row; row += 8) {
      const int residual_rows = std::min(end_row - row, 8);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Initialize with bias.
      if (channel_dimension_is_col) {
        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
        }
      } else {
        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);

        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = initial_accum_data;
        }
      }

      const float* lhs_ptr = lhs_col_ptr;
      const float* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; ++d) {
        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);

        __m256 rhs0_3 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
        __m256 rhs4_7 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));

        const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
        accum_data_v[0] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_0, accum_data_v[0]);

        const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85);
        accum_data_v[1] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_1, accum_data_v[1]);

        const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170);
        accum_data_v[2] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_2, accum_data_v[2]);

        const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255);
        accum_data_v[3] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_3, accum_data_v[3]);

        const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0);
        accum_data_v[4] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_4, accum_data_v[4]);

        const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85);
        accum_data_v[5] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_5, accum_data_v[5]);

        const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170);
        accum_data_v[6] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_6, accum_data_v[6]);

        const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255);
        accum_data_v[7] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_7, accum_data_v[7]);

        lhs_ptr += 8;
        rhs_ptr += 8;
      }

      for (int j = 0; j < residual_cols; ++j) {
        float* block_ptr = dst_ptr + j * dst_stride;
        accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
        accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
        intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
                                              accum_data_v[j]);
      }
    }  // End row-block loop.
  }    // End col-block terminal conditional.
}

template <Path path>
inline void KernelFloatAvxCommonSingleCol(
    const KernelParamsFloat<8, 8>& params) {
  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  // Strides in params are in bytes; divide by sizeof(float) to get elements.
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  //
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  // AVX2 float block size = 8.
  const int end_row = std::min(params.dst_rows, params.last_row + 8);

  float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_col_ptr = params.bias;

  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);

  __m256 accum_data_v;

  const float* rhs_col_ptr = params.rhs_base_ptr;
  float* dst_col_ptr = adj_dst_col_ptr;

  int row = params.start_row;
  for (; row <= end_row - 8; row += 8) {
    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
    float* dst_ptr = dst_col_ptr + row;
    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

    // Initialize with bias.
    accum_data_v = _mm256_loadu_ps(bias_ptr);

    const float* lhs_ptr = lhs_col_ptr;
    const float* rhs_ptr = rhs_col_ptr;
    int d = 0;
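    // The packed RHS advances 8 floats (the full RHS block width) per depth
    // step even though only one column is needed, hence the rhs_ptr[0], [8],
    // [16], [24] indexing in this unrolled loop.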
    for (; d <= params.depth - 4; d += 4) {
      const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr);
      const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_0, dup_rhs_element_0,
                                                accum_data_v);
      const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]);
      const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_1, dup_rhs_element_1,
                                                accum_data_v);

      const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16);
      const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_2, dup_rhs_element_2,
                                                accum_data_v);
      const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]);
      const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_3, dup_rhs_element_3,
                                                accum_data_v);
      lhs_ptr += 32;  // Loaded 8 * 4 floats.
      rhs_ptr += 32;
    }
    for (; d < params.depth; ++d) {
      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
      const float* rhs_data = rhs_ptr;

      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
      accum_data_v =
          intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
      lhs_ptr += 8;
      rhs_ptr += 8;
    }

    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
    _mm256_storeu_ps(dst_ptr, accum_data_v);
  }  // End row-block loop.

  if (row < end_row) {
    const int residual_rows = end_row - row;
    RUY_CHECK_GE(residual_rows, 1);
    RUY_CHECK_LT(residual_rows, 8);

    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
    float* dst_ptr = dst_col_ptr + row;
    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

    // Initialize with bias.
    accum_data_v = _mm256_loadu_ps(bias_ptr);

    const float* lhs_ptr = lhs_col_ptr;
    const float* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; ++d) {
      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
      const float* rhs_data = rhs_ptr;

      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
      accum_data_v =
          intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
      lhs_ptr += 8;
      rhs_ptr += 8;
    }

    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
    intrin_utils::mm256_n_storeu_ps<path>(dst_ptr, residual_rows, accum_data_v);
  }  // End handling of residual rows.
}
}  // namespace ruy
#endif  //  (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)

#endif  // RUY_RUY_KERNEL_X86_H_