/* Copyright 2020 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/kernel_common.h"
#include "ruy/kernel_x86.h"
#include "ruy/opt_set.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX && RUY_OPT(ASM)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))

void Kernel8bitAvx(const KernelParams8bit<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvx(const KernelParamsFloat<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX && RUY_OPT(ASM)

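// Block geometry for the 8-bit kernels below: each pass of the row/column
// loops produces an 8x8 block of the destination, and the inner loop
// consumes the depth dimension 4 levels at a time from the packed data.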
static constexpr int kAvx8bitBlockSize = 8;
static constexpr int kAvx8bitInnerSize = 4;

namespace {
namespace intrin_utils {

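// The helpers below emulate AVX2-style 256-bit integer operations on plain
// AVX, which lacks them: each __m256i argument is split into two __m128i
// halves, processed with 128-bit SSE instructions, and recombined with
// _mm256_set_m128i.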
template <>
inline __m256i mm256_shuffle_epi8<Path::kAvx>(const __m256i& a,
                                              const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i dst_lo = _mm_shuffle_epi8(a_lo, b_lo);
  __m128i dst_hi = _mm_shuffle_epi8(a_hi, b_hi);
  return _mm256_set_m128i(dst_hi, dst_lo);
}

template <>
inline __m128i mm256_extracti128_si256<Path::kAvx>(const __m256i& a,
                                                   const int imm) {
  switch (imm) {
    case 0:
      return _mm256_extractf128_si256(a, 0);
    case 1:
      return _mm256_extractf128_si256(a, 1);
    default:
      RUY_DCHECK_LT(imm, 2);
      return _mm_setzero_si128();
  }
}

template <Path path>
inline __m256i mm256_cvtepi8_epi16(const __m128i& a) {
  // Take the upper 64 bits of a and put in the first 64 bits of 'hi'
  __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128());
  return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a));
}

template <Path path>
inline __m256i mm256_cvtepi32_epi64(const __m128i& a) {
  // sign extend the 32-bit values in the lower 64 bits of a.
  __m128i lo = _mm_cvtepi32_epi64(a);
  __m128i hi = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(a, _mm_setzero_si128()));
  return _mm256_set_m128i(hi, lo);
}

inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b,
                                 const int imm) {
  __m128i tmp = _mm_setzero_si128();
  if (!(imm & 8)) {
    switch (imm & 3) {
      case 0:
        return _mm256_extractf128_si256(a, 0);
      case 1:
        return _mm256_extractf128_si256(a, 1);
      case 2:
        return _mm256_extractf128_si256(b, 0);
      case 3:
        return _mm256_extractf128_si256(b, 1);
    }
  }
  return tmp;
}

template <Path path>
inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b,
                                        const int imm) {
  const int lo_imm = imm & 15;
  __m128i lo = mm_permute_helper(a, b, lo_imm);
  const int hi_imm = (imm >> 4) & 15;
  __m128i hi = mm_permute_helper(a, b, hi_imm);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_max_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_max_epi32(a_lo, b_lo);
  __m128i hi = _mm_max_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_min_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_min_epi32(a_lo, b_lo);
  __m128i hi = _mm_min_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_add_epi32(a_lo, b_lo);
  __m128i hi = _mm_add_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_add_epi64(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_add_epi64(a_lo, b_lo);
  __m128i hi = _mm_add_epi64(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_slli_epi64(const __m256i& a, int imm) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i lo = _mm_slli_epi64(a_lo, imm);
  __m128i hi = _mm_slli_epi64(a_hi, imm);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_mullo_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_mullo_epi32(a_lo, b_lo);
  __m128i hi = _mm_mullo_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

// Defined as a macro since `imm` must be an immediate.
#define BlendM128_epi32(a, b, imm) \
  _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), imm))

// Defined as a macro since `imm` must be an immediate.
#define BlendM128_epi64(a, b, imm) \
  _mm_castpd_si128(_mm_blend_pd(_mm_castsi128_pd(a), _mm_castsi128_pd(b), imm))

// Defined as a macro since `imm` must be an immediate.
#define mm256_blend_epi32(ans, a, b, imm)              \
  __m128i a_lo = _mm256_extractf128_si256(a, 0);       \
  __m128i a_hi = _mm256_extractf128_si256(a, 1);       \
  __m128i b_lo = _mm256_extractf128_si256(b, 0);       \
  __m128i b_hi = _mm256_extractf128_si256(b, 1);       \
  __m128i lo = BlendM128_epi32(a_lo, b_lo, imm & 0xe); \
  __m128i hi = BlendM128_epi32(a_hi, b_hi, imm >> 4);  \
  ans = _mm256_set_m128i(hi, lo);

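// Defined as a macro since `imm` must be an immediate. Also writes the two
// extracted 128-bit halves of `a` into the caller-provided a_lo/a_hi.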
#define mm256_shuffle_epi32(ans, a, a_lo, a_hi, imm)   \
  a_lo = _mm256_extractf128_si256(a, 0);               \
  a_hi = _mm256_extractf128_si256(a, 1);               \
  ans = _mm256_set_m128i(_mm_shuffle_epi32(a_hi, imm), \
                         _mm_shuffle_epi32(a_lo, imm));

template <Path path>
inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_madd_epi16(a_lo, b_lo);
  __m128i hi = _mm_madd_epi16(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m128i mm_srlv_epi64(const __m128i& a, const __m128i& b) {
  // shift both elements of a by lower 64bits of b.
  __m128i res_lo = _mm_srl_epi64(a, b);
  // shift both elements of a by upper 64bits of b.
  __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128());
  __m128i res_hi = _mm_srl_epi64(a, hi_count);
  // Take the lower 64 bits of res_lo and upper 64 bits of res hi
  // 1. Swap the upper and lower 64 bits of res_hi
  __m128i tmp_hi =
      _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1));
  // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
  return _mm_unpacklo_epi64(res_lo, tmp_hi);
}

template <Path path>
inline __m256i mm256_srlv_epi64(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = mm_srlv_epi64(a_lo, b_lo);
  __m128i hi = mm_srlv_epi64(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m128i mm_sllv_epi64(const __m128i& a, const __m128i& b) {
  // shift both elements of a by lower 64bits of b.
  __m128i res_lo = _mm_sll_epi64(a, b);
  // shift both elements of a by upper 64bits of b.
  __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128());
  __m128i res_hi = _mm_sll_epi64(a, hi_count);
  // Take the lower 64 bits of res_lo and upper 64 bits of res hi
  // 1. Swap the upper and lower 64 bits of res_hi
  __m128i tmp_hi =
      _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1));
  // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
  return _mm_unpacklo_epi64(res_lo, tmp_hi);
}

template <Path path>
inline __m256i mm256_sllv_epi64(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = mm_sllv_epi64<path>(a_lo, b_lo);
  __m128i hi = mm_sllv_epi64<path>(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

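// Defined as a macro since `imm` must be an immediate.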
#define PermuteM128_epi32(a, imm) \
  _mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(a), imm));

inline __m128i mm_sllv_epi32(const __m128i& a, const __m128i& b) {
  // shift all elements of a by first 32bits of b.
  __m128i res0 = _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1));

  // put bits 32-63 of b in the first slot.
  __m128i tmp1 = PermuteM128_epi32(b, 1);
  // put bits 32-63 of a in the first slot.
  __m128i a1 = PermuteM128_epi32(a, 1);
  // shift all elements of a by second 32bits of b.
  __m128i res1 =
      _mm_sll_epi32(a1, BlendM128_epi32(_mm_setzero_si128(), tmp1, 1));

  // put bits 64-95 of b in the first slot.
  __m128i tmp2 = PermuteM128_epi32(b, 2);
  // shift all elements of a by third 32bits of b.
  __m128i res2 =
      _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), tmp2, 1));

  // put bits 96-127 of b in the first slot.
  __m128i tmp3 = PermuteM128_epi32(b, 3);
  // put bits 96-127 of a in the third slot.
  __m128i a3 = PermuteM128_epi32(a, 48);
  // shift all elements of a3 by fourth 32bits of b.
  __m128i res3 =
      _mm_sll_epi32(a3, BlendM128_epi32(_mm_setzero_si128(), tmp3, 1));

  // Take bits 0-31 of res0, bits 0-31 of res1,
  // bits 64-95 of res2, and bits 64-95 of res3.
  // res0 _ _ _ 0
  // res1 _ _ _ 1
  // res2 _ 2 _ _
  // res3 _ 3 _ _
  // f_01 _ _ 1 0
  // f_23 _ _ 3 2
  __m128i f_01 = _mm_unpacklo_epi32(res0, res1);
  __m128i f_23 = _mm_unpackhi_epi32(res2, res3);
  // The lower 64 bits of f_01 and the lower 64 bits of f_23.
  return _mm_unpacklo_epi64(f_01, f_23);
}

template <Path path>
inline __m256i mm256_sllv_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = mm_sllv_epi32(a_lo, b_lo);
  __m128i hi = mm_sllv_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_sub_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_sub_epi32(a_lo, b_lo);
  __m128i hi = _mm_sub_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_mul_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_mul_epi32(a_lo, b_lo);
  __m128i hi = _mm_mul_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

// Perform the equivalent of mm256_permutevar8x32 with
// a second argument of {7, 5, 3, 1, 6, 4, 2, 0}
template <Path path>
inline __m256i PermuteEpi32EvenOdds(const __m256i& a) {
  // a_lo = 3 2 1 0
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  // a_hi = 7 6 5 4
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  // shuffle a_lo to get 3 1 2 0
  __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8);
  // shuffle a_hi to get 7 5 6 4
  __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8);
  // unpack lo 64 of tmp_lo and tmp_hi to get 6 4 2 0
  __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi);
  // unpack hi 64 of tmp_lo and tmp_hi to get 7 5 3 1
  __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi);
  return _mm256_set_m128i(res_hi, res_lo);
}

template <Path path>
inline __m256i AddBiasEpi32(const __m256i& a, const int32_t* bias, int offset) {
  const __m256i bias0 = _mm256_set1_epi32(*(bias + offset));
  return mm256_add_epi32<path>(a, bias0);
}

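// Blends 32-bit lanes of a and b according to the sign bit of each lane of
// mask (the float blendv selects on the most significant bit of each lane).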
__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
                           const __m256i& mask) {
  __m256 result =
      _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
                       _mm256_castsi256_ps(mask));
  return _mm256_castps_si256(result);
}

template <Path path>
inline __m256i mm256_cmpgt_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_cmpgt_epi32(a_lo, b_lo);
  __m128i hi = _mm_cmpgt_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_srav_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);

  __m128i r0 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 0));
  __m128i r1 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 1));
  __m128i r2 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 2));
  __m128i r3 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 3));
  __m128i r4 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 4));
  __m128i r5 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 5));
  __m128i r6 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 6));
  __m128i r7 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 7));

  // get element 0 from r0, element 1 from r1
  __m128i r01 = BlendM128_epi32(r0, r1, 2);
  // get element 2 from r2, element 3 from r3
  __m128i r23 = BlendM128_epi32(r2, r3, 8);
  // get element 0 from r4, element 1 from r5
  __m128i r45 = BlendM128_epi32(r4, r5, 2);
  // get element 2 from r6, element 3 from r7
  __m128i r67 = BlendM128_epi32(r6, r7, 8);
  // get lower 64 bits of r01, upper 64 bits of r23
  __m128i r0123 = BlendM128_epi64(r01, r23, 2);
  // get lower 64 bits of r45, upper 64 bits of r67
  __m128i r4567 = BlendM128_epi64(r45, r67, 2);
  return _mm256_set_m128i(r4567, r0123);
}

// AVX doesn't have fused multiply-add so we define an inline function to be
// used in the common code following.
template <>
inline __m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b,
                                 const __m256& c) {
  const __m256 prod = _mm256_mul_ps(a, b);
  return _mm256_add_ps(prod, c);
}

}  // namespace intrin_utils
}  // namespace

template <Path path>
void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx 8-bit");
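  // Shuffle pattern used to split packed LHS data: within each 128-bit half,
  // it gathers the even pairs of bytes (0,1, 4,5, 8,9, 12,13) followed by the
  // odd pairs (2,3, 6,7, 10,11, 14,15), i.e. depth levels 0-1 vs. 2-3 of each
  // row.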
  const std::int8_t splitter_idx_data[32] = {
      0, 1, 4, 5, 8,  9,  12, 13,  //
      2, 3, 6, 7, 10, 11, 14, 15,  //
      0, 1, 4, 5, 8,  9,  12, 13,  //
      2, 3, 6, 7, 10, 11, 14, 15   //
  };

  std::int32_t dst_stride = 0;
  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
    dst_stride = params.dst_stride;
  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int16_t);
  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int32_t);
  } else {
    RUY_DCHECK(false);
  }

  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;

  for (int col = params.start_col; col <= params.last_col;
       col += kAvx8bitBlockSize) {
    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
    void* dst_ptr = dst_col_ptr;

    const std::int32_t lhs_zero_point = params.lhs_zero_point;
    const bool has_rhs_sums_offsets =
        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
    std::int32_t rhs_sums_offsets[8];
    if (has_rhs_sums_offsets) {
      const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>(
          _mm256_set1_epi32(lhs_zero_point),
          _mm256_loadu_si256(
              reinterpret_cast<__m256i const*>(&params.rhs_sums[col])));
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
                          rhs_sums_offset_v);
    }

    for (int row = params.start_row; row <= params.last_row;
         row += kAvx8bitBlockSize) {
      int channel =
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
      int multiplier_channel =
          (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
      const int residual_rows =
          std::min(params.dst_rows - row, kAvx8bitBlockSize);
      const int residual_cols =
          std::min(params.dst_cols - col, kAvx8bitBlockSize);

      const __m256i splitter_idx = _mm256_loadu_si256(
          reinterpret_cast<__m256i const*>(splitter_idx_data));

      __m256i accum_data_v0;
      __m256i accum_data_v1;
      __m256i accum_data_v2;
      __m256i accum_data_v3;
      __m256i accum_data_v4;
      __m256i accum_data_v5;
      __m256i accum_data_v6;
      __m256i accum_data_v7;

      // initial_accum_data will be used to initialize each of the
      // accum_data_* accumulator registers. We compute into it terms that are
      // identical across columns.
      __m128i initial_accum_data_lo = _mm_set1_epi32(params.prod_zp_depth);
      __m128i initial_accum_data_hi = _mm_set1_epi32(params.prod_zp_depth);

      // In the channels-are-rows case, we can load bias here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        initial_accum_data_lo = _mm_add_epi32(
            initial_accum_data_lo,
            _mm_loadu_si128(
                reinterpret_cast<const __m128i*>(params.bias + row)));
        initial_accum_data_hi = _mm_add_epi32(
            initial_accum_data_hi,
            _mm_loadu_si128(
                reinterpret_cast<const __m128i*>(params.bias + row + 4)));
      }

      // Adjustments common across columns.
      const std::int32_t rhs_zero_point = params.rhs_zero_point;
      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
        const __m128i rhs_zp = _mm_set1_epi32(rhs_zero_point);
        const __m128i lhs_sums_offset_lo = _mm_mullo_epi32(
            rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>(
                        &params.lhs_sums[row])));
        const __m128i lhs_sums_offset_hi = _mm_mullo_epi32(
            rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>(
                        &params.lhs_sums[row + 4])));

        initial_accum_data_lo =
            _mm_sub_epi32(initial_accum_data_lo, lhs_sums_offset_lo);
        initial_accum_data_hi =
            _mm_sub_epi32(initial_accum_data_hi, lhs_sums_offset_hi);
      }

      // Adjustments differing across columns.
      if (has_rhs_sums_offsets) {
        __m256i initial_accum_data =
            _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo);

        accum_data_v0 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
        accum_data_v1 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1]));
        accum_data_v2 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2]));
        accum_data_v3 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3]));
        accum_data_v4 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4]));
        accum_data_v5 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5]));
        accum_data_v6 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6]));
        accum_data_v7 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7]));
      } else {
        __m256i initial_accum_data =
            _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo);
        accum_data_v0 = initial_accum_data;
        accum_data_v1 = initial_accum_data;
        accum_data_v2 = initial_accum_data;
        accum_data_v3 = initial_accum_data;
        accum_data_v4 = initial_accum_data;
        accum_data_v5 = initial_accum_data;
        accum_data_v6 = initial_accum_data;
        accum_data_v7 = initial_accum_data;
      }

      // Finally, in the channels-are-columns case, load bias data here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        accum_data_v0 = intrin_utils::AddBiasEpi32<path>(accum_data_v0,
                                                         params.bias + col, 0);
        accum_data_v1 = intrin_utils::AddBiasEpi32<path>(accum_data_v1,
                                                         params.bias + col, 1);
        accum_data_v2 = intrin_utils::AddBiasEpi32<path>(accum_data_v2,
                                                         params.bias + col, 2);
        accum_data_v3 = intrin_utils::AddBiasEpi32<path>(accum_data_v3,
                                                         params.bias + col, 3);
        accum_data_v4 = intrin_utils::AddBiasEpi32<path>(accum_data_v4,
                                                         params.bias + col, 4);
        accum_data_v5 = intrin_utils::AddBiasEpi32<path>(accum_data_v5,
                                                         params.bias + col, 5);
        accum_data_v6 = intrin_utils::AddBiasEpi32<path>(accum_data_v6,
                                                         params.bias + col, 6);
        accum_data_v7 = intrin_utils::AddBiasEpi32<path>(accum_data_v7,
                                                         params.bias + col, 7);
      }

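      // Accumulation loop over the depth dimension, consuming
      // kAvx8bitInnerSize (4) levels of depth per iteration.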
      const std::int8_t* lhs_ptr = lhs_col_ptr;
      const std::int8_t* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
        const __m256i lhs_data =
            _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
        const __m256i rhs_data_8bit =
            _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));

        // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
        std::int32_t rhs_data[16];
        const __m128i rhs_data_bottom_lane =
            _mm256_castsi256_si128(rhs_data_8bit);
        const __m128i rhs_data_top_lane =
            _mm256_extractf128_si256(rhs_data_8bit, 1);
        const __m256i rhs_16_bit_dup_low =
            intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_bottom_lane);
        const __m256i rhs_16_bit_dup_high =
            intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_top_lane);
        // Now that we have cast the RHS data, we store it so that each value
        // can be separately loaded in the accumulation loop.
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data),
                            rhs_16_bit_dup_low);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8),
                            rhs_16_bit_dup_high);

        // NOTE: There may be opportunities for permuting the data in the
        // packing code instead of here.
        const __m256i lhs_data_split =
            intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx);
        const __m256i lhs_data_split_expand_bottom =
            intrin_utils::mm256_cvtepi8_epi16<path>(
                _mm256_extractf128_si256(lhs_data_split, 0));
        const __m256i lhs_data_split_expand_top =
            intrin_utils::mm256_cvtepi8_epi16<path>(
                _mm256_extractf128_si256(lhs_data_split, 1));

        // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
        const __m256i lhs_16_bit_low =
            intrin_utils::mm256_permute2x128_si256<path>(
                lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
        // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
        const __m256i lhs_16_bit_high =
            intrin_utils::mm256_permute2x128_si256<path>(
                lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);

        __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(
            rhs_data));  // Load [0 1 2 3 4 5 6 7]
        __m256i rhs1 = _mm256_lddqu_si256(
            reinterpret_cast<const __m256i*>(rhs_data + 8));  // Load [8 - 15]
        __m256i rhs0_3 =
            _mm256_permute2f128_si256(rhs0, rhs0, 0);  // [0 1 2 3 0 1 2 3]
        __m256i rhs4_7 =
            _mm256_permute2f128_si256(rhs0, rhs0, 0x11);  // [4 5 6 7 4 5 6 7]
        __m256i rhs8_11 =
            _mm256_permute2f128_si256(rhs1, rhs1, 0);  // [8 9 10 11 8 9 10 11]
        __m256i rhs12_15 =
            _mm256_permute2f128_si256(rhs1, rhs1, 17);  // [12 - 15, 12 - 15]

        auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi,
                                  __m256i& accum) {
          // Perform mul-adds on low and high components of accum separately.
          __m128i accum_lo = _mm256_extractf128_si256(accum, 0);
          __m128i accum_hi = _mm256_extractf128_si256(accum, 1);

          __m128i lhs_lo_0 = _mm256_extractf128_si256(lhs_16_bit_low, 0);
          __m128i lhs_lo_1 = _mm256_extractf128_si256(lhs_16_bit_low, 1);
          __m128i rhs_dup_lo_0 = _mm256_extractf128_si256(rhs_dup_lo, 0);
          __m128i rhs_dup_lo_1 = _mm256_extractf128_si256(rhs_dup_lo, 1);
          __m128i lo_0 = _mm_madd_epi16(lhs_lo_0, rhs_dup_lo_0);
          __m128i lo_1 = _mm_madd_epi16(lhs_lo_1, rhs_dup_lo_1);

          accum_lo = _mm_add_epi32(accum_lo, lo_0);
          accum_hi = _mm_add_epi32(accum_hi, lo_1);

          __m128i lhs_hi_0 = _mm256_extractf128_si256(lhs_16_bit_high, 0);
          __m128i lhs_hi_1 = _mm256_extractf128_si256(lhs_16_bit_high, 1);
          __m128i rhs_dup_hi_0 = _mm256_extractf128_si256(rhs_dup_hi, 0);
          __m128i rhs_dup_hi_1 = _mm256_extractf128_si256(rhs_dup_hi, 1);
          __m128i hi_0 = _mm_madd_epi16(lhs_hi_0, rhs_dup_hi_0);
          __m128i hi_1 = _mm_madd_epi16(lhs_hi_1, rhs_dup_hi_1);

          accum_lo = _mm_add_epi32(accum_lo, hi_0);
          accum_hi = _mm_add_epi32(accum_hi, hi_1);
          accum = _mm256_set_m128i(accum_hi, accum_lo);
        };
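        // Each process_column call accumulates one destination column: the
        // two broadcast RHS vectors carry that column's 16-bit values for
        // depth levels 0-1 and 2-3, which are madd-ed against the even-pair
        // (lhs_16_bit_low) and odd-pair (lhs_16_bit_high) LHS data.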
        __m256i tmp0, tmp1, tmp2, tmp3;
        __m128i lo0, lo1, hi0, hi1;
        mm256_shuffle_epi32(tmp0, rhs0_3, lo0, hi0, 0);
        mm256_shuffle_epi32(tmp1, rhs0_3, lo1, hi1, 0x55);
        process_column(tmp0, tmp1, accum_data_v0);
        mm256_shuffle_epi32(tmp2, rhs0_3, lo0, hi0, 0xaa);
        mm256_shuffle_epi32(tmp3, rhs0_3, lo1, hi1, 0xff);
        process_column(tmp2, tmp3, accum_data_v1);

        mm256_shuffle_epi32(tmp0, rhs4_7, lo0, hi0, 0);
        mm256_shuffle_epi32(tmp1, rhs4_7, lo1, hi1, 0x55);
        process_column(tmp0, tmp1, accum_data_v2);
        mm256_shuffle_epi32(tmp2, rhs4_7, lo0, hi0, 0xaa);
        mm256_shuffle_epi32(tmp3, rhs4_7, lo1, hi1, 0xff);
        process_column(tmp2, tmp3, accum_data_v3);

        mm256_shuffle_epi32(tmp0, rhs8_11, lo0, hi0, 0);
        mm256_shuffle_epi32(tmp1, rhs8_11, lo1, hi1, 0x55);
        process_column(tmp0, tmp1, accum_data_v4);
        mm256_shuffle_epi32(tmp2, rhs8_11, lo0, hi0, 0xaa);
        mm256_shuffle_epi32(tmp3, rhs8_11, lo1, hi1, 0xff);
        process_column(tmp2, tmp3, accum_data_v5);

        mm256_shuffle_epi32(tmp0, rhs12_15, lo0, hi0, 0);
        mm256_shuffle_epi32(tmp1, rhs12_15, lo1, hi1, 0x55);
        process_column(tmp0, tmp1, accum_data_v6);
        mm256_shuffle_epi32(tmp2, rhs12_15, lo0, hi0, 0xaa);
        mm256_shuffle_epi32(tmp3, rhs12_15, lo1, hi1, 0xff);
        process_column(tmp2, tmp3, accum_data_v7);

        lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
        rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
      }

      if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
        __m256i m_vector;
        __m256i e_vector;
        // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
        m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            params.multiplier_fixedpoint + multiplier_channel));
        e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            params.multiplier_exponent + multiplier_channel));

        const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>(
            _mm256_extractf128_si256(m_vector, 0));
        const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>(
            _mm256_extractf128_si256(m_vector, 1));

        const __m256i zero_vector = _mm256_setzero_si256();
        const __m256i left_shift =
            intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector);
        const __m256i neg_e_vector =
            intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
        const __m256i right_shift =
            intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
        const __m256i final_right_shift = _mm256_set1_epi32(31);
        const __m256i final_right_shift_low =
            intrin_utils::mm256_cvtepi32_epi64<path>(
                _mm256_extractf128_si256(final_right_shift, 0));
        const __m256i final_right_shift_high =
            intrin_utils::mm256_cvtepi32_epi64<path>(
                _mm256_extractf128_si256(final_right_shift, 1));
        const __m256i convert_to_unsigned_64 =
            _mm256_set1_epi64x(0x8000000000000000);

        __m256i post_scaling_offset = _mm256_setzero_si256();

        // A "half" added for rounding prior to truncation of 64-bit value.
        const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
            intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
            convert_to_unsigned_64);

        if (params.dst_zero_point) {
          post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
        }

        // We cannot do
        //
        // scaled_v_low =
        //     _mm256_srav_epi64(scaled_v_low, final_right_shift_low);
        // scaled_v_high =
        //     _mm256_srav_epi64(scaled_v_high, final_right_shift_high);
        //
        // since this instruction is not in AVX2. Instead we use
        // _mm256_srlv_epi64, but this is an unsigned shift, so we applied
        // offsets before (convert_to_unsigned_64) and after
        // (convert_to_signed_halved).
        //
        // The overall process is, for 64-bit scaled accumulator:
        // unsigned_accum = signed_accum + 1 << 63;
        // unsigned_accum = (unsigned_accum >> right_shift) >> 31;
        // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2;

        // There are various ways to repack the results, in the absence of
        // _mm256_cvtepi64_epi32() or anything like it.
        // A.
        // accum_data_v[j] =
        //     _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6),
        //                      _mm256_extract_epi32(scaled_v_high, 4),
        //                      _mm256_extract_epi32(scaled_v_high, 2),
        //                      _mm256_extract_epi32(scaled_v_high, 0),
        //                      _mm256_extract_epi32(scaled_v_low, 6),
        //                      _mm256_extract_epi32(scaled_v_low, 4),
        //                      _mm256_extract_epi32(scaled_v_low, 2),
        //                      _mm256_extract_epi32(scaled_v_low, 0));
        // B.
        // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8);
        // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8);
        // accum_data_v[j] =
        //     _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2),
        //                       _mm256_extract_epi64(scaled_v_high, 0),
        //                       _mm256_extract_epi64(scaled_v_low, 2),
        //                       _mm256_extract_epi64(scaled_v_low, 0));
        // C.
        // scaled_v_low =
        //     _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm);
        // scaled_v_high =
        //     _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm);
        // accum_data_v[j] =
        //     _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20);
        //
        // However, we choose the following because it uses two lighter
        // instructions. The permutation does have a longer latency, but this
        // loop can be unrolled.
        // D.
        // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
        // __m256i results =
        //     _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
        // results = _mm256_permutevar8x32_epi32(results, repack_perm);
        // accum_data_v[j] = intrin_utils::mm256_add_epi32<path>(results,
        // post_scaling_offset);

        // This multiplier code is complex and expensive enough on x86, that
        // we prefer to implement the channels-are-columns case by transposing
        // around it, rather than duplicate it (which would also require
        // duplicating the above code computing the multiplier constants).
        // This is one instance where channels-are-columns has lower performance
        // than channels-are-rows.
        const bool transpose_around_multiplier =
            (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
            (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
        if (transpose_around_multiplier) {
          // Transpose the 8x8 accumulators block. Will be un-transposed below
          // after the multiplier implementation.
          intrin_utils::mm256_transpose8x8_epi32<path>(
              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
        }

        auto rounding_right_shift = [=](__m256i& results,
                                        const __m256i& exponent) {
          // Construct the "nudge" value for each lane if the exponent is
          // greater than 0. Otherwise, the nudge is 0.
          const __m256i zeros = _mm256_setzero_si256();
          const __m256i mask_rightshift_gtz =
              intrin_utils::mm256_cmpgt_epi32<path>(exponent, zeros);
          const __m256i one_shift_exp_minus1 =
              intrin_utils::mm256_sllv_epi32<path>(
                  _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
                                            exponent, _mm256_set1_epi32(1)));
          __m256i nudge = intrin_utils::mm256_blendv_epi32(
              zeros, one_shift_exp_minus1, mask_rightshift_gtz);
          // Calculate the shifted sum (results + nudge) >> exp.
          const __m256i r_plus_nudge =
              intrin_utils::mm256_add_epi32<path>(results, nudge);
          const __m256i shifted_sum =
              intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, exponent);

          // Identify overflow in each lane and create mask.
          const __m256i one_shift_31minus_exp =
              intrin_utils::mm256_sllv_epi32<path>(
                  _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
                                            _mm256_set1_epi32(31), exponent));
          const __m256i mask_num_plus_nudge_overflow =
              intrin_utils::mm256_cmpgt_epi32<path>(
                  results, intrin_utils::mm256_sub_epi32<path>(
                               _mm256_set1_epi32(0x7fffffff), nudge));
          // Fill results with either (results + nudge) >> exponent or
          // 1 << (31 - exp) in the case of overflow.
          results = intrin_utils::mm256_blendv_epi32(
              shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
        };

        auto apply_multiplier = [=](__m256i& accum) {
          __m256i shifted_accum =
              intrin_utils::mm256_sllv_epi32<path>(accum, left_shift);
          // Apply the fixed-point part of the multiplier.
          __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>(
              intrin_utils::mm256_cvtepi32_epi64<path>(
                  _mm256_extractf128_si256(shifted_accum, 0)),
              m_64bit_low);
          __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>(
              intrin_utils::mm256_cvtepi32_epi64<path>(
                  _mm256_extractf128_si256(shifted_accum, 1)),
              m_64bit_high);
          scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
                                                             offset_vector);
          scaled_v_high = intrin_utils::mm256_add_epi64<path>(
              scaled_v_high, offset_vector);

          scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
              scaled_v_low, final_right_shift_low);
          scaled_v_high = intrin_utils::mm256_srlv_epi64<path>(
              scaled_v_high, final_right_shift_high);
          scaled_v_high =
              intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32);
          __m256i results;
          mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa);
          // Permute results to this ordering of int32 elements
          // lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
          results = intrin_utils::PermuteEpi32EvenOdds<path>(results);

          rounding_right_shift(results, right_shift);
          accum =
              intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
        };
        apply_multiplier(accum_data_v0);
        apply_multiplier(accum_data_v1);
        apply_multiplier(accum_data_v2);
        apply_multiplier(accum_data_v3);
        apply_multiplier(accum_data_v4);
        apply_multiplier(accum_data_v5);
        apply_multiplier(accum_data_v6);
        apply_multiplier(accum_data_v7);
        // See above comment: here we transpose again to undo the transposition
        // of the 8x8 block of accumulators used to implement the
        // channels-are-columns case.
        if (transpose_around_multiplier) {
          intrin_utils::mm256_transpose8x8_epi32<path>(
              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
        }
      }
      const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
      const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
      const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
                                    (residual_cols == kAvx8bitBlockSize);

      __m256i accum_data_v[kAvx8bitBlockSize];
      if (!store_full_block) {
        accum_data_v[0] = accum_data_v0;
        accum_data_v[1] = accum_data_v1;
        accum_data_v[2] = accum_data_v2;
        accum_data_v[3] = accum_data_v3;
        accum_data_v[4] = accum_data_v4;
        accum_data_v[5] = accum_data_v5;
        accum_data_v[6] = accum_data_v6;
        accum_data_v[7] = accum_data_v7;
      }

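      // Store the 8x8 block, clamping and narrowing according to the
      // destination type; partial blocks at the right/bottom edges take the
      // residual-size path.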
      if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
        std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
          accum_data_v0 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
          accum_data_v1 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
          accum_data_v1 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
          accum_data_v2 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
          accum_data_v2 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
          accum_data_v3 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
          accum_data_v3 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
          accum_data_v4 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
          accum_data_v4 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
          accum_data_v5 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
          accum_data_v5 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
          accum_data_v6 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
          accum_data_v6 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
          accum_data_v7 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
          accum_data_v7 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[0 * dst_stride], accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[1 * dst_stride], accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
            result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
        std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
          accum_data_v0 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
          accum_data_v1 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
          accum_data_v1 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
          accum_data_v2 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
          accum_data_v2 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
          accum_data_v3 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
          accum_data_v3 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
          accum_data_v4 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
          accum_data_v4 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
          accum_data_v5 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
          accum_data_v5 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
          accum_data_v6 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
          accum_data_v6 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
          accum_data_v7 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
          accum_data_v7 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0],
                                                         accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride],
                                                         accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
            result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
        std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
          accum_data_v0 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
          accum_data_v1 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
          accum_data_v1 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
          accum_data_v2 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
          accum_data_v2 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
          accum_data_v3 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
          accum_data_v3 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
          accum_data_v4 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
          accum_data_v4 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
          accum_data_v5 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
          accum_data_v5 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
          accum_data_v6 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
          accum_data_v6 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
          accum_data_v7 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
          accum_data_v7 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0],
                                                          accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride],
                                                          accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
            result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
        if (store_full_block) {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride],
                                                 accum_data_v1);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride],
                                                 accum_data_v2);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride],
                                                 accum_data_v3);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride],
                                                 accum_data_v4);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride],
                                                 accum_data_v5);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride],
                                                 accum_data_v6);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride],
                                                 accum_data_v7);
        } else {
          std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < residual_cols; ++j) {
            intrin_utils::mm256_n_storeu_epi32<path>(
                dst_block_ptr, residual_rows, accum_data_v[j]);
            dst_block_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else {
        RUY_DCHECK(false);
      }

      lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
    }  // End row-block loop.

    dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                     kAvx8bitBlockSize * params.dst_stride);
    rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
  }  // End col-block loop.
}  // NOLINT(readability/fn_size)

void Kernel8bitAvx(const KernelParams8bit<8, 8>& params) {
  Kernel8bitAvxImpl<Path::kAvx>(params);
}

1169 template <Path path>
1170 void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx 8-bit GEMV");

  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  const std::int8_t splitter_idx_data[32] = {
      0, 1, 4, 5, 8,  9,  12, 13,  //
      2, 3, 6, 7, 10, 11, 14, 15,  //
      0, 1, 4, 5, 8,  9,  12, 13,  //
      2, 3, 6, 7, 10, 11, 14, 15   //
  };
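  // Rough sketch of what splitter_idx does below (within each 128-bit lane):
  // mm256_shuffle_epi8 gathers the even-offset byte pairs {0,1, 4,5, 8,9,
  // 12,13} into the low 8 bytes and the odd-offset pairs {2,3, 6,7, 10,11,
  // 14,15} into the high 8 bytes. With the packed LHS holding 4 consecutive
  // depth values per row, this appears to separate depths {0,1} from depths
  // {2,3} so that each half can be sign-extended to 16 bits and fed to
  // madd_epi16 against a matching pair of RHS values.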

  int bias_ptr_block_increment =
      params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;

  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  const std::int32_t* bias_col_ptr = params.bias;
  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
    bias_col_ptr += params.start_row;
  }

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  void* dst_ptr = dst_col_ptr;
  const std::int32_t* bias_ptr = bias_col_ptr;

  const std::int32_t lhs_zero_point = params.lhs_zero_point;
  const bool has_rhs_sums_offsets =
      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
  std::int32_t rhs_sums_offsets[8];
  if (has_rhs_sums_offsets) {
    const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>(
        _mm256_set1_epi32(lhs_zero_point),
        _mm256_loadu_si256(
            reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
                        rhs_sums_offset_v);
  }

  for (int row = params.start_row; row <= params.last_row;
       row += kAvx8bitBlockSize) {
    const int residual_rows =
        std::min(params.dst_rows - row, kAvx8bitBlockSize);

    const __m256i splitter_idx =
        _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data));

    __m256i accum_data_v0;

    // Initialize with bias.
    __m256i initial_accum_data =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr));
    bias_ptr += bias_ptr_block_increment;

    // Adjustments common across columns.
    const std::int32_t rhs_zero_point = params.rhs_zero_point;
    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
      const __m256i lhs_sums_offset = intrin_utils::mm256_mullo_epi32<path>(
          _mm256_set1_epi32(rhs_zero_point),
          _mm256_loadu_si256(
              reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
      initial_accum_data = intrin_utils::mm256_sub_epi32<path>(
          initial_accum_data, lhs_sums_offset);
    }
    const std::int32_t prod_zp_depth = params.prod_zp_depth;
    if (prod_zp_depth) {
      initial_accum_data = intrin_utils::mm256_add_epi32<path>(
          initial_accum_data, _mm256_set1_epi32(prod_zp_depth));
    }

    // Adjustments differing across columns.
    if (has_rhs_sums_offsets) {
      accum_data_v0 = intrin_utils::mm256_sub_epi32<path>(
          initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
    } else {
      accum_data_v0 = initial_accum_data;
    }
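    // At this point accum_data_v0 holds, for each of the 8 rows in this block,
    // roughly:
    //   bias[row] - rhs_zero_point * lhs_sums[row]
    //             - lhs_zero_point * rhs_sums[0] + prod_zp_depth
    // which corresponds to expanding
    //   sum_d (lhs - lhs_zero_point) * (rhs - rhs_zero_point)
    // so that the depth loop below only needs to accumulate the raw
    // sum_d lhs * rhs products.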

    const std::int8_t* lhs_ptr = lhs_col_ptr;
    const std::int8_t* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
      const __m256i lhs_data =
          _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
      const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr);

      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
      // For simplicity we load 4x the data that we need and process twice the
      // data that we need and store only the data we need.
      std::int32_t rhs_data[2];
      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
      // Now that we have cast the RHS data, we store it so that each value
      // can be separately loaded in the accumulation loop.
      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);

      // NOTE: There may be opportunities for permuting the data in the packing
      // code instead of here.
      const __m256i lhs_data_split =
          intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx);
      const __m256i lhs_data_split_expand_bottom =
          intrin_utils::mm256_cvtepi8_epi16<path>(
              _mm256_extractf128_si256(lhs_data_split, 0));
      const __m256i lhs_data_split_expand_top =
          intrin_utils::mm256_cvtepi8_epi16<path>(
              _mm256_extractf128_si256(lhs_data_split, 1));

      // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
      const __m256i lhs_16_bit_low =
          intrin_utils::mm256_permute2x128_si256<path>(
              lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
      // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
      const __m256i lhs_16_bit_high =
          intrin_utils::mm256_permute2x128_si256<path>(
              lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
      // Accumulate for column 0.
      const std::int32_t low_rhs_value = rhs_data[0];
      const std::int32_t high_rhs_value = rhs_data[1];

      const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
      const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);

      accum_data_v0 = intrin_utils::mm256_add_epi32<path>(
          accum_data_v0, intrin_utils::mm256_madd_epi16<path>(
                             lhs_16_bit_low, rhs_16_bit_dup_low));
      accum_data_v0 = intrin_utils::mm256_add_epi32<path>(
          accum_data_v0, intrin_utils::mm256_madd_epi16<path>(
                             lhs_16_bit_high, rhs_16_bit_dup_high));

      lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
      rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
    }
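    // Each iteration of the loop above consumes 4 levels of depth for all 8
    // rows of the block. In rough scalar terms, for each row r it computes
    // something like
    //   accum[r] += lhs[r][d + 0] * rhs[d + 0] + lhs[r][d + 1] * rhs[d + 1]
    //             + lhs[r][d + 2] * rhs[d + 2] + lhs[r][d + 3] * rhs[d + 3];
    // where the first pair comes from the "low" madd_epi16 and the second
    // pair from the "high" one, each operating in 16-bit arithmetic before
    // the 32-bit accumulation.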

    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
      __m256i m_vector;
      __m256i e_vector;
      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
      int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
      m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          params.multiplier_fixedpoint + channel));
      e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          params.multiplier_exponent + channel));
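      // multiplier_fixedpoint is the fixed-point (roughly Q0.31) part of the
      // requantization scale and multiplier_exponent its power-of-two part,
      // so the effective scale applied below is approximately
      //   multiplier_fixedpoint * 2^multiplier_exponent / 2^31,
      // loaded per channel (row) when RUY_ASM_FLAG_HAS_PERCHANNEL is set and
      // from the start of the multiplier arrays otherwise.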

      const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>(
          _mm256_extractf128_si256(m_vector, 0));
      const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>(
          _mm256_extractf128_si256(m_vector, 1));

      const __m256i zero_vector = _mm256_setzero_si256();
      const __m256i left_shift =
          intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector);
      const __m256i neg_e_vector =
          intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
      const __m256i right_shift =
          intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
      const __m256i final_right_shift = _mm256_set1_epi32(31);
      const __m256i final_right_shift_low =
          intrin_utils::mm256_cvtepi32_epi64<path>(
              _mm256_extractf128_si256(final_right_shift, 0));
      const __m256i final_right_shift_high =
          intrin_utils::mm256_cvtepi32_epi64<path>(
              _mm256_extractf128_si256(final_right_shift, 1));
      const __m256i convert_to_unsigned_64 =
          _mm256_set1_epi64x(0x8000000000000000);

      __m256i post_scaling_offset = _mm256_setzero_si256();

      // A "half" added for rounding prior to truncation of 64-bit value.
      const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
          intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
          convert_to_unsigned_64);
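      // offset_vector effectively packs two things into one add: (1 << 30) is
      // half of the 2^31 divisor applied by the final shift, giving
      // round-to-nearest behavior, and the 2^63 bias maps the signed 64-bit
      // products to an offset-unsigned form so that the logical srlv_epi64
      // below should produce the same low 32 bits as an arithmetic shift.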

      if (params.dst_zero_point) {
        post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
      }

      // See GEMM version for details of this process.
      {
        __m256i shifted_accum =
            intrin_utils::mm256_sllv_epi32<path>(accum_data_v0, left_shift);
        // Apply the fixed-point part of the multiplier.
        __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>(
            intrin_utils::mm256_cvtepi32_epi64<path>(
                _mm256_extractf128_si256(shifted_accum, 0)),
            m_64bit_low);
        __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>(
            intrin_utils::mm256_cvtepi32_epi64<path>(
                _mm256_extractf128_si256(shifted_accum, 1)),
            m_64bit_high);

        scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
                                                           offset_vector);
        scaled_v_high = intrin_utils::mm256_add_epi64<path>(scaled_v_high,
                                                            offset_vector);

        scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
            scaled_v_low, final_right_shift_low);
        scaled_v_high = intrin_utils::mm256_srlv_epi64<path>(
            scaled_v_high, final_right_shift_high);

        scaled_v_high = intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32);
        __m256i results;
        mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa);
        // Permute results to this ordering of int32 elements
        // lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
        results = intrin_utils::PermuteEpi32EvenOdds<path>(results);

        // Now perform the Rounding Right Shift.
        // First, construct the "nudge" value for each lane if the exponent is
        // greater than 0. Otherwise, the nudge is 0.
        const __m256i zeros = _mm256_setzero_si256();
        const __m256i mask_rightshift_gtz =
            intrin_utils::mm256_cmpgt_epi32<path>(right_shift, zeros);
        const __m256i one_shift_exp_minus1 =
            intrin_utils::mm256_sllv_epi32<path>(
                _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
                                          right_shift, _mm256_set1_epi32(1)));
        __m256i nudge = intrin_utils::mm256_blendv_epi32(
            zeros, one_shift_exp_minus1, mask_rightshift_gtz);
        // Calculate the shifted sum (results + nudge) >> exp.
        const __m256i r_plus_nudge =
            intrin_utils::mm256_add_epi32<path>(results, nudge);
        const __m256i shifted_sum =
            intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, right_shift);

        // Identify overflow in each lane and create mask.
        const __m256i one_shift_31minus_exp =
            intrin_utils::mm256_sllv_epi32<path>(
                _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
                                          _mm256_set1_epi32(31), right_shift));
        const __m256i mask_num_plus_nudge_overflow =
            intrin_utils::mm256_cmpgt_epi32<path>(
                results, intrin_utils::mm256_sub_epi32<path>(
                             _mm256_set1_epi32(0x7fffffff), nudge));
        // Fill results with either (results + nudge) >> exponent or
        // 1 << (31 - exp) in the case of overflow.
        results = intrin_utils::mm256_blendv_epi32(
            shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
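        // In scalar form, the rounding right shift above is roughly:
        //   nudge   = (exp > 0) ? (1 << (exp - 1)) : 0;
        //   shifted = (x + nudge) >> exp;  // arithmetic shift
        //   if (x > INT32_MAX - nudge) shifted = 1 << (31 - exp);
        // where exp is the per-lane right_shift and the last line guards
        // against overflow of the x + nudge addition.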
        accum_data_v0 =
            intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
      }
    }
    const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
    const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);

    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
      result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
                                                       result);
      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
      result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
                                                       result);
      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
      result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows,
                                                        result);
      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
      std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
      intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows,
                                               accum_data_v0);
      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else {
      RUY_DCHECK(false);
    }

    lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
  }  // End row-block loop.

  dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                   kAvx8bitBlockSize * params.dst_stride);
  rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
}  // NOLINT(readability/fn_size)

void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params) {
  Kernel8bitAvxSingleColImpl<Path::kAvx>(params);
}

void KernelFloatAvx(const KernelParamsFloat<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx float");
  KernelFloatAvxCommon<Path::kAvx>(params);
}

void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx float GEMV");
  KernelFloatAvxCommonSingleCol<Path::kAvx>(params);
}

#endif  //  RUY_PLATFORM_AVX && RUY_OPT(ASM)

}  // namespace ruy