/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/kernel_common.h"
#include "ruy/kernel_x86.h"
#include "ruy/opt_set.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))

void Kernel8bitAvx2(const KernelParams8bit<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvx2(const KernelParamsFloat<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)

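// Block-size constants for the 8-bit kernels below: each outer iteration
// computes an 8x8 destination block, and the inner loop consumes the packed
// depth dimension 4 levels at a time (pairs of 8-bit values are widened to
// 16-bit and combined with _mm256_madd_epi16).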
static constexpr int kAvx8bitBlockSize = 8;
static constexpr int kAvx8bitInnerSize = 4;

namespace {
namespace intrin_utils {

template <>
inline __m256i mm256_shuffle_epi8<Path::kAvx2Fma>(const __m256i& a,
                                                  const __m256i& b) {
  return _mm256_shuffle_epi8(a, b);
}

// Make an inline function for FMA so we can share the float kernels
// with non-FMA code.
template <>
inline __m256 MulAdd<Path::kAvx2Fma>(const __m256& a, const __m256& b,
                                     const __m256& c) {
  return _mm256_fmadd_ps(a, b, c);
}

template <>
inline __m128i mm256_extracti128_si256<Path::kAvx2Fma>(const __m256i& a,
                                                       const int imm) {
  switch (imm) {
    case 0:
      return _mm256_extracti128_si256(a, 0);
    case 1:
      return _mm256_extracti128_si256(a, 1);
    default:
      RUY_DCHECK_LT(imm, 2);
      return _mm_setzero_si128();
  }
}

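// Selects each 32-bit lane from `b` where the sign bit of the corresponding
// lane of `mask` is set, and from `a` otherwise. Implemented by bit-casting
// to float and using _mm256_blendv_ps, which blends on the per-lane sign bit.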
__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
                           const __m256i& mask) {
  __m256 result =
      _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
                       _mm256_castsi256_ps(mask));
  return _mm256_castps_si256(result);
}

}  // namespace intrin_utils
}  // namespace

template <Path path>
void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit");
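  // Shuffle control used to split the packed LHS data: within each 128-bit
  // lane, bytes 0, 1, 4, 5, ... are gathered into the low 8 bytes and bytes
  // 2, 3, 6, 7, ... into the high 8 bytes, so that after sign extension the
  // LHS lines up with the pairs of 16-bit RHS values fed to
  // _mm256_madd_epi16.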
  const std::int8_t splitter_idx_data[32] = {
      0, 1, 4, 5, 8,  9,  12, 13,  //
      2, 3, 6, 7, 10, 11, 14, 15,  //
      0, 1, 4, 5, 8,  9,  12, 13,  //
      2, 3, 6, 7, 10, 11, 14, 15   //
  };

  std::int32_t dst_stride = 0;
  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
    dst_stride = params.dst_stride;
  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int16_t);
  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int32_t);
  } else {
    RUY_DCHECK(false);
  }

  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;

  for (int col = params.start_col; col <= params.last_col;
       col += kAvx8bitBlockSize) {
    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
    void* dst_ptr = dst_col_ptr;

    const std::int32_t lhs_zero_point = params.lhs_zero_point;
    const bool has_rhs_sums_offsets =
        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
    std::int32_t rhs_sums_offsets[8];
    if (has_rhs_sums_offsets) {
      const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
          _mm256_set1_epi32(lhs_zero_point),
          _mm256_loadu_si256(
              reinterpret_cast<__m256i const*>(&params.rhs_sums[col])));
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
                          rhs_sums_offset_v);
    }

    for (int row = params.start_row; row <= params.last_row;
         row += kAvx8bitBlockSize) {
      int channel =
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
      int multiplier_channel =
          (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
      const int residual_rows =
          std::min(params.dst_rows - row, kAvx8bitBlockSize);
      const int residual_cols =
          std::min(params.dst_cols - col, kAvx8bitBlockSize);

      const __m256i splitter_idx = _mm256_loadu_si256(
          reinterpret_cast<__m256i const*>(splitter_idx_data));

      __m256i accum_data_v0;
      __m256i accum_data_v1;
      __m256i accum_data_v2;
      __m256i accum_data_v3;
      __m256i accum_data_v4;
      __m256i accum_data_v5;
      __m256i accum_data_v6;
      __m256i accum_data_v7;

      // initial_accum_data will be the initial value of each of the
      // accum_data_* accumulator registers. We compute into it the terms that
      // are identical across columns.
      __m256i initial_accum_data = _mm256_set1_epi32(params.prod_zp_depth);

      // In the channels-are-rows case, we can load bias here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        initial_accum_data = _mm256_add_epi32(
            initial_accum_data,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(params.bias + row)));
      }

      // Adjustments common across columns.
      const std::int32_t rhs_zero_point = params.rhs_zero_point;
      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
        const __m256i lhs_sums_offset = _mm256_mullo_epi32(
            _mm256_set1_epi32(rhs_zero_point),
            _mm256_loadu_si256(
                reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
        initial_accum_data =
            _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
      }

      // Adjustments differing across columns.
      if (has_rhs_sums_offsets) {
        accum_data_v0 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
        accum_data_v1 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1]));
        accum_data_v2 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2]));
        accum_data_v3 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3]));
        accum_data_v4 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4]));
        accum_data_v5 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5]));
        accum_data_v6 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6]));
        accum_data_v7 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7]));
      } else {
        accum_data_v0 = initial_accum_data;
        accum_data_v1 = initial_accum_data;
        accum_data_v2 = initial_accum_data;
        accum_data_v3 = initial_accum_data;
        accum_data_v4 = initial_accum_data;
        accum_data_v5 = initial_accum_data;
        accum_data_v6 = initial_accum_data;
        accum_data_v7 = initial_accum_data;
      }

      // Finally, in the channels-are-columns case, load bias data here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        const __m256i bias_data = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(params.bias + col));
        accum_data_v0 = _mm256_add_epi32(
            accum_data_v0,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(0)));
        accum_data_v1 = _mm256_add_epi32(
            accum_data_v1,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(1)));
        accum_data_v2 = _mm256_add_epi32(
            accum_data_v2,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(2)));
        accum_data_v3 = _mm256_add_epi32(
            accum_data_v3,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(3)));
        accum_data_v4 = _mm256_add_epi32(
            accum_data_v4,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(4)));
        accum_data_v5 = _mm256_add_epi32(
            accum_data_v5,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(5)));
        accum_data_v6 = _mm256_add_epi32(
            accum_data_v6,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(6)));
        accum_data_v7 = _mm256_add_epi32(
            accum_data_v7,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(7)));
      }

      const std::int8_t* lhs_ptr = lhs_col_ptr;
      const std::int8_t* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
        const __m256i lhs_data =
            _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
        const __m256i rhs_data_8bit =
            _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));

        // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
        std::int32_t rhs_data[16];
        const __m128i rhs_data_bottom_lane =
            _mm256_castsi256_si128(rhs_data_8bit);
        const __m128i rhs_data_top_lane =
            _mm256_extracti128_si256(rhs_data_8bit, 1);
        const __m256i rhs_16_bit_dup_low =
            _mm256_cvtepi8_epi16(rhs_data_bottom_lane);
        const __m256i rhs_16_bit_dup_high =
            _mm256_cvtepi8_epi16(rhs_data_top_lane);
        // Now that we have cast the RHS data, we store it so that each value
        // can be separately loaded in the accumulation loop.
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data),
                            rhs_16_bit_dup_low);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8),
                            rhs_16_bit_dup_high);

        const __m256i lhs_data_split =
            _mm256_shuffle_epi8(lhs_data, splitter_idx);
        const __m256i lhs_data_split_expand_bottom =
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
        const __m256i lhs_data_split_expand_top =
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));

        // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
        const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
            lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
        // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
        const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
            lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);

        __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(
            rhs_data));  // Load [0 1 2 3 4 5 6 7]
        __m256i rhs1 = _mm256_lddqu_si256(
            reinterpret_cast<const __m256i*>(rhs_data + 8));  // Load [8 - 15]
        __m256i rhs0_3 =
            _mm256_permute2f128_si256(rhs0, rhs0, 0);  // [0 1 2 3 0 1 2 3]
        __m256i rhs4_7 =
            _mm256_permute2f128_si256(rhs0, rhs0, 0x11);  // [4 5 6 7 4 5 6 7]
        __m256i rhs8_11 =
            _mm256_permute2f128_si256(rhs1, rhs1, 0);  // [8 9 10 11 8 9 10 11]
        __m256i rhs12_15 =
            _mm256_permute2f128_si256(rhs1, rhs1, 17);  // [12 - 15, 12 - 15]

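        // Accumulate one destination column: rhs_dup_lo/rhs_dup_hi broadcast
        // that column's pairs of 16-bit RHS values across all lanes, and
        // _mm256_madd_epi16 multiplies them with the matching split LHS
        // halves, summing adjacent 16-bit products into the column's 32-bit
        // accumulator lanes.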
        auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi,
                                  __m256i& accum) {
          accum = _mm256_add_epi32(
              accum, _mm256_madd_epi16(lhs_16_bit_low, rhs_dup_lo));
          accum = _mm256_add_epi32(
              accum, _mm256_madd_epi16(lhs_16_bit_high, rhs_dup_hi));
        };
        __m256i tmp0, tmp1, tmp2, tmp3;
        tmp0 = _mm256_shuffle_epi32(rhs0_3, 0);
        tmp1 = _mm256_shuffle_epi32(rhs0_3, 0x55);
        process_column(tmp0, tmp1, accum_data_v0);
        tmp2 = _mm256_shuffle_epi32(rhs0_3, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs0_3, 0xff);
        process_column(tmp2, tmp3, accum_data_v1);

        tmp0 = _mm256_shuffle_epi32(rhs4_7, 0);
        tmp1 = _mm256_shuffle_epi32(rhs4_7, 0x55);
        process_column(tmp0, tmp1, accum_data_v2);
        tmp2 = _mm256_shuffle_epi32(rhs4_7, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs4_7, 0xff);
        process_column(tmp2, tmp3, accum_data_v3);

        tmp0 = _mm256_shuffle_epi32(rhs8_11, 0);
        tmp1 = _mm256_shuffle_epi32(rhs8_11, 0x55);
        process_column(tmp0, tmp1, accum_data_v4);
        tmp2 = _mm256_shuffle_epi32(rhs8_11, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs8_11, 0xff);
        process_column(tmp2, tmp3, accum_data_v5);

        tmp0 = _mm256_shuffle_epi32(rhs12_15, 0);
        tmp1 = _mm256_shuffle_epi32(rhs12_15, 0x55);
        process_column(tmp0, tmp1, accum_data_v6);
        tmp2 = _mm256_shuffle_epi32(rhs12_15, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs12_15, 0xff);
        process_column(tmp2, tmp3, accum_data_v7);

        lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
        rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
      }

      if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
        __m256i m_vector;
        __m256i e_vector;
        // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
        m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            params.multiplier_fixedpoint + multiplier_channel));
        e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            params.multiplier_exponent + multiplier_channel));

        const __m256i m_64bit_low =
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
        const __m256i m_64bit_high =
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));

        const __m256i zero_vector = _mm256_setzero_si256();
        const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
        const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
        const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
        const __m256i final_right_shift = _mm256_set1_epi32(31);
        const __m256i final_right_shift_low = _mm256_cvtepi32_epi64(
            _mm256_extracti128_si256(final_right_shift, 0));
        const __m256i final_right_shift_high = _mm256_cvtepi32_epi64(
            _mm256_extracti128_si256(final_right_shift, 1));
        const __m256i convert_to_unsigned_64 =
            _mm256_set1_epi64x(0x8000000000000000);

        __m256i post_scaling_offset = _mm256_setzero_si256();
        // A "half" added for rounding prior to truncation of 64-bit value.
        const __m256i offset_vector = _mm256_add_epi64(
            _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
            convert_to_unsigned_64);

        if (params.dst_zero_point) {
          post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
        }

        const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);

        // We cannot do
        //
        // scaled_v_low =
        //     _mm256_srav_epi64(scaled_v_low, final_right_shift_low);
        // scaled_v_high =
        //     _mm256_srav_epi64(scaled_v_high, final_right_shift_high);
        //
        // since this instruction is not in AVX2. Instead we use
        // _mm256_srlv_epi64, but this is an unsigned shift, so we applied
        // offsets before (convert_to_unsigned_64) and after
        // (convert_to_signed_halved).
        //
        // The overall process is, for 64-bit scaled accumulator:
        // unsigned_accum = signed_accum + 1 << 63;
        // unsigned_accum = (unsigned_accum >> right_shift) >> 31;
        // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2;

        // There are various ways to repack the results, in the absence of
        // _mm256_cvtepi64_epi32() or anything like it.
        // A.
        // accum_data_v[j] =
        //     _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6),
        //                      _mm256_extract_epi32(scaled_v_high, 4),
        //                      _mm256_extract_epi32(scaled_v_high, 2),
        //                      _mm256_extract_epi32(scaled_v_high, 0),
        //                      _mm256_extract_epi32(scaled_v_low, 6),
        //                      _mm256_extract_epi32(scaled_v_low, 4),
        //                      _mm256_extract_epi32(scaled_v_low, 2),
        //                      _mm256_extract_epi32(scaled_v_low, 0));
        // B.
        // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8);
        // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8);
        // accum_data_v[j] =
        //     _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2),
        //                       _mm256_extract_epi64(scaled_v_high, 0),
        //                       _mm256_extract_epi64(scaled_v_low, 2),
        //                       _mm256_extract_epi64(scaled_v_low, 0));
        // C.
        // scaled_v_low =
        //     _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm);
        // scaled_v_high =
        //     _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm);
        // accum_data_v[j] =
        //     _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20);
        //
        // However, we choose the following because it uses two lighter
        // instructions. The permutation does have a longer latency, but this
        // loop can be unrolled.
        // D.
        // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
        // __m256i results =
        //     _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
        // results = _mm256_permutevar8x32_epi32(results, repack_perm);
        // accum_data_v[j] = _mm256_add_epi32(results, post_scaling_offset);

        // This multiplier code is complex and expensive enough on x86 that
        // we prefer to implement the channels-are-columns case by transposing
        // around it, rather than duplicate it (which would also require
        // duplicating the above code computing the multiplier constants).
        // This is one instance where channels-are-columns has lower
        // performance than channels-are-rows.
        const bool transpose_around_multiplier =
            (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
            (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
        if (transpose_around_multiplier) {
          // Transpose the 8x8 accumulator block. It will be un-transposed
          // below, after the multiplier implementation.
          intrin_utils::mm256_transpose8x8_epi32<path>(
              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
        }

        auto rounding_right_shift = [=](__m256i& results,
                                        const __m256i& exponent) {
          // Construct the "nudge" value for each lane if the exponent is
          // greater than 0. Otherwise, the nudge is 0.
          const __m256i zeros = _mm256_setzero_si256();
          const __m256i mask_rightshift_gtz =
              _mm256_cmpgt_epi32(exponent, zeros);
          const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
              _mm256_set1_epi32(1),
              _mm256_sub_epi32(exponent, _mm256_set1_epi32(1)));
          __m256i nudge = intrin_utils::mm256_blendv_epi32(
              zeros, one_shift_exp_minus1, mask_rightshift_gtz);
          // Calculate the shifted sum (results + nudge) >> exp.
          const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
          const __m256i shifted_sum = _mm256_srav_epi32(r_plus_nudge, exponent);

          // Identify overflow in each lane and create mask.
          const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
              _mm256_set1_epi32(1),
              _mm256_sub_epi32(_mm256_set1_epi32(31), exponent));
          const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
              results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
          // Fill results with either (results + nudge) >> exponent or
          // 1 << (31 - exp) in the case of overflow.
          results = intrin_utils::mm256_blendv_epi32(
              shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
        };
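        // Illustrative example: with an exponent of 3 the nudge is
        // 1 << 2 = 4, so a lane value of 13 becomes (13 + 4) >> 3 = 2,
        // i.e. 13/8 rounded to nearest. The overflow mask covers lanes near
        // INT32_MAX where adding the nudge would wrap around.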

        auto apply_multiplier = [=](__m256i& accum) {
          __m256i shifted_accum = _mm256_sllv_epi32(accum, left_shift);
          // Apply the fixed-point part of the multiplier.
          __m256i scaled_v_low = _mm256_mul_epi32(
              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
              m_64bit_low);
          __m256i scaled_v_high = _mm256_mul_epi32(
              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
              m_64bit_high);

          scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
          scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);

          scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
          scaled_v_high =
              _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);

          scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
          __m256i results =
              _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
          results = _mm256_permutevar8x32_epi32(results, repack_perm);
          // Now do a Rounding Right Shift.
          rounding_right_shift(results, right_shift);
          accum = _mm256_add_epi32(results, post_scaling_offset);
        };
        apply_multiplier(accum_data_v0);
        apply_multiplier(accum_data_v1);
        apply_multiplier(accum_data_v2);
        apply_multiplier(accum_data_v3);
        apply_multiplier(accum_data_v4);
        apply_multiplier(accum_data_v5);
        apply_multiplier(accum_data_v6);
        apply_multiplier(accum_data_v7);
        // See above comment: here we transpose again to undo the transposition
        // of the 8x8 block of accumulators used to implement the
        // channels-are-columns case.
        if (transpose_around_multiplier) {
          intrin_utils::mm256_transpose8x8_epi32<path>(
              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
        }
      }
      const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
      const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
      const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
                                    (residual_cols == kAvx8bitBlockSize);

      __m256i accum_data_v[kAvx8bitBlockSize];
      if (!store_full_block) {
        accum_data_v[0] = accum_data_v0;
        accum_data_v[1] = accum_data_v1;
        accum_data_v[2] = accum_data_v2;
        accum_data_v[3] = accum_data_v3;
        accum_data_v[4] = accum_data_v4;
        accum_data_v[5] = accum_data_v5;
        accum_data_v[6] = accum_data_v6;
        accum_data_v[7] = accum_data_v7;
      }

      if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
        std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
          accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
          accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
          accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
          accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
          accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
          accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
          accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
          accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
          accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
          accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
          accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
          accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
          accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
          accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
          accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[0 * dst_stride], accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[1 * dst_stride], accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = _mm256_min_epi32(result, clamp_max_v);
            result = _mm256_max_epi32(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
        std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
          accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
          accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
          accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
          accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
          accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
          accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
          accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
          accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
          accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
          accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
          accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
          accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
          accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
          accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
          accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0],
                                                         accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride],
                                                         accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = _mm256_min_epi32(result, clamp_max_v);
            result = _mm256_max_epi32(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
        std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
          accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
          accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
          accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
          accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
          accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
          accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
          accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
          accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
          accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
          accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
          accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
          accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
          accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
          accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
          accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0],
                                                          accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride],
                                                          accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = _mm256_min_epi32(result, clamp_max_v);
            result = _mm256_max_epi32(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
        if (store_full_block) {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride],
                                                 accum_data_v1);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride],
                                                 accum_data_v2);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride],
                                                 accum_data_v3);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride],
                                                 accum_data_v4);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride],
                                                 accum_data_v5);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride],
                                                 accum_data_v6);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride],
                                                 accum_data_v7);
        } else {
          std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < residual_cols; ++j) {
            intrin_utils::mm256_n_storeu_epi32<path>(
                dst_block_ptr, residual_rows, accum_data_v[j]);
            dst_block_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else {
        RUY_DCHECK(false);
      }

      lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
    }  // End row-block loop.

    dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                     kAvx8bitBlockSize * params.dst_stride);
    rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
  }  // End col-block loop.
}  // NOLINT(readability/fn_size)

void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
  Kernel8bitAvx2Impl<Path::kAvx2Fma>(params);
}

template <Path path>
void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit GEMV");

  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

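  // Same LHS splitter shuffle control as in Kernel8bitAvx2Impl above.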
  const std::int8_t splitter_idx_data[32] = {
      0, 1, 4, 5, 8,  9,  12, 13,  //
      2, 3, 6, 7, 10, 11, 14, 15,  //
      0, 1, 4, 5, 8,  9,  12, 13,  //
      2, 3, 6, 7, 10, 11, 14, 15   //
  };

  int bias_ptr_block_increment =
      params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;

  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  const std::int32_t* bias_col_ptr = params.bias;
  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
    bias_col_ptr += params.start_row;
  }

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  void* dst_ptr = dst_col_ptr;
  const std::int32_t* bias_ptr = bias_col_ptr;

  const std::int32_t lhs_zero_point = params.lhs_zero_point;
  const bool has_rhs_sums_offsets =
      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
  std::int32_t rhs_sums_offsets[8];
  if (has_rhs_sums_offsets) {
    const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
        _mm256_set1_epi32(lhs_zero_point),
        _mm256_loadu_si256(
            reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
                        rhs_sums_offset_v);
  }

  for (int row = params.start_row; row <= params.last_row;
       row += kAvx8bitBlockSize) {
    const int residual_rows =
        std::min(params.dst_rows - row, kAvx8bitBlockSize);

    const __m256i splitter_idx =
        _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data));

    __m256i accum_data_v0;

    // Initialize with bias.
    __m256i initial_accum_data =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr));
    bias_ptr += bias_ptr_block_increment;

    // Adjustments common across columns.
    const std::int32_t rhs_zero_point = params.rhs_zero_point;
    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
      const __m256i lhs_sums_offset = _mm256_mullo_epi32(
          _mm256_set1_epi32(rhs_zero_point),
          _mm256_loadu_si256(
              reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
      initial_accum_data =
          _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
    }
    const std::int32_t prod_zp_depth = params.prod_zp_depth;
    if (prod_zp_depth) {
      initial_accum_data = _mm256_add_epi32(initial_accum_data,
                                            _mm256_set1_epi32(prod_zp_depth));
    }

    // Adjustments differing across columns.
    if (has_rhs_sums_offsets) {
      accum_data_v0 = _mm256_sub_epi32(initial_accum_data,
                                       _mm256_set1_epi32(rhs_sums_offsets[0]));
    } else {
      accum_data_v0 = initial_accum_data;
    }

    const std::int8_t* lhs_ptr = lhs_col_ptr;
    const std::int8_t* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
      const __m256i lhs_data =
          _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
      const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr);

      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
      // For simplicity we load 4x the data that we need, process twice the
      // data that we need, and store only the data we need.
      std::int32_t rhs_data[2];
      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
      // Now that we have cast the RHS data, we store it so that each value
      // can be separately loaded in the accumulation loop.
      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);

      // NOTE: There may be opportunities for permuting the data in the packing
      // code instead of here.
      const __m256i lhs_data_split =
          _mm256_shuffle_epi8(lhs_data, splitter_idx);
      const __m256i lhs_data_split_expand_bottom =
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
      const __m256i lhs_data_split_expand_top =
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));

      // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
      const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
          lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
      // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
      const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
          lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
      // Accumulate for column 0.
      const std::int32_t low_rhs_value = rhs_data[0];
      const std::int32_t high_rhs_value = rhs_data[1];

      const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
      const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);

      accum_data_v0 = _mm256_add_epi32(
          accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
      accum_data_v0 = _mm256_add_epi32(
          accum_data_v0,
          _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));

      lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
      rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
    }

    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
      __m256i m_vector;
      __m256i e_vector;
      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
      int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
      m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          params.multiplier_fixedpoint + channel));
      e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          params.multiplier_exponent + channel));

      const __m256i m_64bit_low =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
      const __m256i m_64bit_high =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));

      const __m256i zero_vector = _mm256_setzero_si256();
      const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
      const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
      const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
      const __m256i final_right_shift = _mm256_set1_epi32(31);
      const __m256i final_right_shift_low =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0));
      const __m256i final_right_shift_high =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1));
      const __m256i convert_to_unsigned_64 =
          _mm256_set1_epi64x(0x8000000000000000);

      __m256i post_scaling_offset = _mm256_setzero_si256();
      // A "half" added for rounding prior to truncation of 64-bit value.
      const __m256i offset_vector = _mm256_add_epi64(
          _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
          convert_to_unsigned_64);

      if (params.dst_zero_point) {
        post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
      }

      const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);

      // See GEMM version for details of this process.
      {
        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
        // Apply the fixed-point part of the multiplier.
        __m256i scaled_v_low = _mm256_mul_epi32(
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
            m_64bit_low);
        __m256i scaled_v_high = _mm256_mul_epi32(
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
            m_64bit_high);

        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);

        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
        scaled_v_high =
            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);

        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
        __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
        results = _mm256_permutevar8x32_epi32(results, repack_perm);

        // Now do a Rounding Right Shift.
        // First, construct the nudge value for each lane.
        const __m256i zeros = _mm256_setzero_si256();
        const __m256i mask_rightshift_gtz =
            _mm256_cmpgt_epi32(right_shift, zeros);
        const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
            _mm256_set1_epi32(1),
            _mm256_sub_epi32(right_shift, _mm256_set1_epi32(1)));
        __m256i nudge = intrin_utils::mm256_blendv_epi32(
            zeros, one_shift_exp_minus1, mask_rightshift_gtz);
        // Calculate the shifted sum (results + nudge) >> exp.
        const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
        const __m256i shifted_sum =
            _mm256_srav_epi32(r_plus_nudge, right_shift);

        // Identify overflow in each lane and create mask.
        const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
            _mm256_set1_epi32(1),
            _mm256_sub_epi32(_mm256_set1_epi32(31), right_shift));
        const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
            results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
        // Fill results with either (results + nudge) >> exponent or
        // 1 << (31 - exp) in the case of overflow.
        results = intrin_utils::mm256_blendv_epi32(
            shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);

        accum_data_v0 = _mm256_add_epi32(results, post_scaling_offset);
      }
    }
    const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
    const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);

    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = _mm256_min_epi32(result, clamp_max_v);
      result = _mm256_max_epi32(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
                                                       result);
      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = _mm256_min_epi32(result, clamp_max_v);
      result = _mm256_max_epi32(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
                                                       result);
      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = _mm256_min_epi32(result, clamp_max_v);
      result = _mm256_max_epi32(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows,
                                                        result);
      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
      std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
      intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows,
                                               accum_data_v0);
      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else {
      RUY_DCHECK(false);
    }

    lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
  }  // End row-block loop.

  dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                   kAvx8bitBlockSize * params.dst_stride);
  rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
}  // NOLINT(readability/fn_size)

void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
  Kernel8bitAvx2SingleColImpl<Path::kAvx2Fma>(params);
}

void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma float");
  KernelFloatAvxCommon<Path::kAvx2Fma>(params);
}

void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma float GEMV");
  KernelFloatAvxCommonSingleCol<Path::kAvx2Fma>(params);
}

#endif  //  RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)

}  // namespace ruy