• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <cstdint>
18 
19 #include "ruy/check_macros.h"
20 #include "ruy/kernel_x86.h"
21 #include "ruy/opt_set.h"
22 #include "ruy/platform.h"
23 #include "ruy/profiler/instrumentation.h"
24 
25 #if RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
26 #include <immintrin.h>  // IWYU pragma: keep
27 #endif
28 
29 namespace ruy {
30 
31 #if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM))
32 
Kernel8bitAvx512(const KernelParams8bit<16,16> &)33 void Kernel8bitAvx512(const KernelParams8bit<16, 16>&) {
34   // CPU-ID-based checks should disable the path that would reach this point.
35   RUY_DCHECK(false);
36 }
37 
Kernel8bitAvx512SingleCol(const KernelParams8bit<16,16> &)38 void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>&) {
39   // CPU-ID-based checks should disable the path that would reach this point.
40   RUY_DCHECK(false);
41 }
42 
KernelFloatAvx512(const KernelParamsFloat<16,16> &)43 void KernelFloatAvx512(const KernelParamsFloat<16, 16>&) {
44   // CPU-ID-based checks should disable the path that would reach this point.
45   RUY_DCHECK(false);
46 }
47 
KernelFloatAvx512SingleCol(const KernelParamsFloat<16,16> &)48 void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) {
49   // CPU-ID-based checks should disable the path that would reach this point.
50   RUY_DCHECK(false);
51 }
52 
53 #else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
54 
// AVX-512 kernel for 8-bit quantized matrix multiplication over 16x16
// destination blocks. For each block it: accumulates int8 LHS x int8 RHS
// products into sixteen 512-bit int32 accumulators (one per destination
// column), folds in zero-point/bias corrections, applies the fixed-point
// output multiplier with a rounding right shift (unless the destination is
// raw int32), adds the destination zero point, clamps, and stores as
// int8/uint8/int16/int32. Partial edge blocks are handled with masked stores.
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 8-bit");

  // Convert the byte stride from params into a stride expressed in elements
  // of the destination type.
  std::int32_t dst_stride = 0;
  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
    dst_stride = params.dst_stride;
  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int16_t);
  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int32_t);
  } else {
    RUY_DCHECK(false);
  }

  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;

  // Outer loop over 16-column blocks of the destination.
  for (int col = params.start_col; col <= params.last_col; col += 16) {
    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
    void* dst_ptr = dst_col_ptr;

    // Per-column zero-point correction terms:
    // lhs_zero_point * rhs_sums[col + j]. Skipped when the LHS zero point is
    // 0 or RHS sums were not provided.
    const std::int32_t lhs_zero_point = params.lhs_zero_point;
    const bool has_rhs_sums_offsets =
        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
    std::int32_t rhs_sums_offsets[16];
    if (has_rhs_sums_offsets) {
      const __m512i rhs_sums_offset_v =
          _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
                             _mm512_loadu_si512(&params.rhs_sums[col]));
      _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
                          rhs_sums_offset_v);
    }

    // Inner loop over 16-row blocks of the destination.
    for (int row = params.start_row; row <= params.last_row; row += 16) {
      // "Channel" picks whether per-channel quantization parameters are
      // indexed by row or by column; without per-channel params, index 0 is
      // used for everything.
      int channel =
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
      int multiplier_channel =
          (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;

      // Number of valid rows/cols in this block (partial at matrix edges).
      const int residual_rows = std::min(params.dst_rows - row, 16);
      const int residual_cols = std::min(params.dst_cols - col, 16);

      // One 512-bit accumulator (16 int32 lanes = 16 rows) per destination
      // column of the block.
      __m512i accum_data_v0;
      __m512i accum_data_v1;
      __m512i accum_data_v2;
      __m512i accum_data_v3;
      __m512i accum_data_v4;
      __m512i accum_data_v5;
      __m512i accum_data_v6;
      __m512i accum_data_v7;
      __m512i accum_data_v8;
      __m512i accum_data_v9;
      __m512i accum_data_va;
      __m512i accum_data_vb;
      __m512i accum_data_vc;
      __m512i accum_data_vd;
      __m512i accum_data_ve;
      __m512i accum_data_vf;

      // Lane mask of the valid rows, used for masked stores of partial
      // blocks.
      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // initial_accum_data will be the initialize of each of the
      // accum_data_* accumulator registers. We compute into it terms that are
      // identical across columns.
      __m512i initial_accum_data = _mm512_set1_epi32(params.prod_zp_depth);

      // In the channels-are-rows case, we can load bias here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        initial_accum_data = _mm512_add_epi32(
            initial_accum_data,
            _mm512_loadu_si512(
                reinterpret_cast<const __m512i*>(params.bias + row)));
      }

      // Per-row correction: subtract rhs_zero_point * lhs_sums[row + i],
      // identical across columns so it folds into initial_accum_data.
      const std::int32_t rhs_zero_point = params.rhs_zero_point;
      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
        const __m512i lhs_sums_offset =
            _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
                               _mm512_loadu_si512(&params.lhs_sums[row]));
        initial_accum_data =
            _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
      }

      // Adjustments differing across columns.
      if (has_rhs_sums_offsets) {
        accum_data_v0 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0]));
        accum_data_v1 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1]));
        accum_data_v2 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2]));
        accum_data_v3 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3]));
        accum_data_v4 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4]));
        accum_data_v5 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5]));
        accum_data_v6 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6]));
        accum_data_v7 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7]));
        accum_data_v8 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8]));
        accum_data_v9 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9]));
        accum_data_va = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10]));
        accum_data_vb = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11]));
        accum_data_vc = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12]));
        accum_data_vd = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13]));
        accum_data_ve = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14]));
        accum_data_vf = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15]));
      } else {
        accum_data_v0 = initial_accum_data;
        accum_data_v1 = initial_accum_data;
        accum_data_v2 = initial_accum_data;
        accum_data_v3 = initial_accum_data;
        accum_data_v4 = initial_accum_data;
        accum_data_v5 = initial_accum_data;
        accum_data_v6 = initial_accum_data;
        accum_data_v7 = initial_accum_data;
        accum_data_v8 = initial_accum_data;
        accum_data_v9 = initial_accum_data;
        accum_data_va = initial_accum_data;
        accum_data_vb = initial_accum_data;
        accum_data_vc = initial_accum_data;
        accum_data_vd = initial_accum_data;
        accum_data_ve = initial_accum_data;
        accum_data_vf = initial_accum_data;
      }

      // Finally, in the channels-are-columns case, load bias data here.
      // Each column's accumulator gets bias[col + j] broadcast across all its
      // 16 row lanes via permutexvar with a constant index.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        const __m512i bias_data = _mm512_loadu_si512(
            reinterpret_cast<const __m512i*>(params.bias + col));
        accum_data_v0 = _mm512_add_epi32(
            accum_data_v0,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(0), bias_data));
        accum_data_v1 = _mm512_add_epi32(
            accum_data_v1,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(1), bias_data));
        accum_data_v2 = _mm512_add_epi32(
            accum_data_v2,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(2), bias_data));
        accum_data_v3 = _mm512_add_epi32(
            accum_data_v3,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(3), bias_data));
        accum_data_v4 = _mm512_add_epi32(
            accum_data_v4,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(4), bias_data));
        accum_data_v5 = _mm512_add_epi32(
            accum_data_v5,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(5), bias_data));
        accum_data_v6 = _mm512_add_epi32(
            accum_data_v6,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(6), bias_data));
        accum_data_v7 = _mm512_add_epi32(
            accum_data_v7,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(7), bias_data));
        accum_data_v8 = _mm512_add_epi32(
            accum_data_v8,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(8), bias_data));
        accum_data_v9 = _mm512_add_epi32(
            accum_data_v9,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(9), bias_data));
        accum_data_va = _mm512_add_epi32(
            accum_data_va,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(10), bias_data));
        accum_data_vb = _mm512_add_epi32(
            accum_data_vb,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(11), bias_data));
        accum_data_vc = _mm512_add_epi32(
            accum_data_vc,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(12), bias_data));
        accum_data_vd = _mm512_add_epi32(
            accum_data_vd,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(13), bias_data));
        accum_data_ve = _mm512_add_epi32(
            accum_data_ve,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(14), bias_data));
        accum_data_vf = _mm512_add_epi32(
            accum_data_vf,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(15), bias_data));
      }

      // Accumulation loop: consume 4 levels of depth per iteration, using
      // 16-bit madd (pairs of products summed into 32 bits) twice per column.
      const std::int8_t* lhs_ptr = lhs_col_ptr;
      const std::int8_t* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; d += 4) {
        const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
        __m512i rhs_data_8bit = _mm512_loadu_si512(rhs_ptr);

        // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
        std::int32_t rhs_data[32];
        const __m256i rhs_data_bottom_lane =
            _mm512_castsi512_si256(rhs_data_8bit);
        const __m256i rhs_data_top_lane =
            _mm512_extracti32x8_epi32(rhs_data_8bit, 1);
        const __m512i rhs_16_bit_dup_low =
            _mm512_cvtepi8_epi16(rhs_data_bottom_lane);
        const __m512i rhs_16_bit_dup_high =
            _mm512_cvtepi8_epi16(rhs_data_top_lane);
        // Now that we have cast the RHS data, we store it so that each value
        // can be separately loaded in the accumulation loop.
        // NOTE(review): the destination casts below say __m256i* but the
        // stores are 512-bit; harmless since _mm512_storeu_si512 takes
        // void*, but the cast type looks like a leftover — confirm intent.
        _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data),
                            rhs_16_bit_dup_low);
        _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data + 16),
                            rhs_16_bit_dup_high);

        // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
        const __m512i lhs_16_bit_low =
            _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
        // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
        const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
            _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));

        // For one destination column: broadcast its two packed 16-bit RHS
        // pairs and madd them against the expanded LHS, accumulating int32.
        auto process_column = [=](int col, __m512i& accum) {
          const __m512i rhs_16_bit_dup_low =
              _mm512_set1_epi32(rhs_data[2 * col]);
          const __m512i rhs_16_bit_dup_high =
              _mm512_set1_epi32(rhs_data[2 * col + 1]);

          accum = _mm512_add_epi32(
              accum, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
          accum = _mm512_add_epi32(
              accum, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
        };
        process_column(0, accum_data_v0);
        process_column(1, accum_data_v1);
        process_column(2, accum_data_v2);
        process_column(3, accum_data_v3);
        process_column(4, accum_data_v4);
        process_column(5, accum_data_v5);
        process_column(6, accum_data_v6);
        process_column(7, accum_data_v7);
        process_column(8, accum_data_v8);
        process_column(9, accum_data_v9);
        process_column(10, accum_data_va);
        process_column(11, accum_data_vb);
        process_column(12, accum_data_vc);
        process_column(13, accum_data_vd);
        process_column(14, accum_data_ve);
        process_column(15, accum_data_vf);

        // Advance both packed blocks by 16 (rows/cols) * 4 (depth levels).
        lhs_ptr += 16 * 4;
        rhs_ptr += 16 * 4;
      }

      // Requantization: skipped entirely for raw-int32 destinations.
      if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
        // The non-per-channel case could equivalently be handled in the per-row
        // or per-column code path. The per-row code path is slightly more
        // efficient so we handle it there.
        const bool per_column_multiplier =
            (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
            (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);

        __m512i m_vector;
        __m512i e_vector;
        // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
        m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
            params.multiplier_fixedpoint + multiplier_channel));
        e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
            params.multiplier_exponent + multiplier_channel));

        // Widen the fixed-point multipliers to 64-bit lanes for the
        // high-precision multiply below.
        const __m512i m_64bit_low =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
        const __m512i m_64bit_high =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));

        // Split the exponent into a (non-negative) left shift and a
        // (non-negative) right shift; exactly one of the two is nonzero per
        // channel.
        const __m512i zero_vector = _mm512_setzero_epi32();
        const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
        const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
        const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
        const __m512i final_right_shift = _mm512_set1_epi32(31);
        const __m512i right_shift_low =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
        const __m512i right_shift_high =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
        const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
            _mm512_extracti32x8_epi32(final_right_shift, 0));
        const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
            _mm512_extracti32x8_epi32(final_right_shift, 1));

        // A "half" added for rounding prior to truncation of 64-bit value.
        const __m512i offset_vector =
            _mm512_slli_epi64(_mm512_set1_epi64(1), 30);

        // Rounding arithmetic right shift of 64-bit lanes by per-lane
        // exponent, with saturation of the overflow case.
        auto rounding_right_shift = [=](__m512i& results,
                                        const __m512i& exponent) {
          // Construct the "nudge" value for each lane if the exponent is
          // greater than 0. Otherwise, the nudge is 0.
          const __m512i zeros = _mm512_setzero_si512();
          const auto mask_rightshift_gtz =
              _mm512_cmpgt_epi64_mask(exponent, zeros);
          const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(
              _mm512_set1_epi64(1),
              _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
          __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
                                                one_shift_exp_minus1);
          // Calculate the shifted sum (results + nudge) >> exp.
          const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
          const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);

          // Identify overflow in each lane and create mask.
          const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
              _mm512_set1_epi64(1),
              _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
          const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
              results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
          // Fill results with either (results + nudge) >> exponent or
          // 1 << (31 - exp) in the case of overflow.
          results = _mm512_mask_mov_epi64(
              shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
        };

        if (per_column_multiplier) {
          // Per-column multipliers: each column's accumulator is scaled by
          // the single channel value at index `col`, broadcast to all lanes.
          auto apply_multiplier = [=](__m512i& accum, int col) {
            __m512i perm_64bit_vals = _mm512_set1_epi64(col % 8);
            // Apply the fixed-point part of the multiplier.
            __m512i left_shift_val =
                _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift);
            __m512i m_64bit_val = _mm512_permutexvar_epi64(
                perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
            __m512i offset_vector_val =
                _mm512_permutexvar_epi64(perm_64bit_vals, offset_vector);
            __m512i final_right_shift_val = _mm512_permutexvar_epi64(
                perm_64bit_vals,
                col < 8 ? final_right_shift_low : final_right_shift_high);
            __m512i right_shift_val = _mm512_permutexvar_epi64(
                perm_64bit_vals, col < 8 ? right_shift_low : right_shift_high);

            accum = _mm512_sllv_epi32(accum, left_shift_val);
            // 32x32 -> 64-bit multiply, low/high halves separately.
            __m512i scaled_v_low = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
                m_64bit_val);
            __m512i scaled_v_high = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
                m_64bit_val);

            scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_val);
            scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_val);

            scaled_v_low =
                _mm512_srav_epi64(scaled_v_low, final_right_shift_val);
            scaled_v_high =
                _mm512_srav_epi64(scaled_v_high, final_right_shift_val);

            rounding_right_shift(scaled_v_low, right_shift_val);
            rounding_right_shift(scaled_v_high, right_shift_val);

            // Narrow the two 64-bit halves back into one 512-bit vector of
            // 32-bit results.
            accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
            accum = _mm512_inserti32x8(accum,
                                       _mm512_cvtepi64_epi32(scaled_v_high), 1);
          };
          apply_multiplier(accum_data_v0, 0);
          apply_multiplier(accum_data_v1, 1);
          apply_multiplier(accum_data_v2, 2);
          apply_multiplier(accum_data_v3, 3);
          apply_multiplier(accum_data_v4, 4);
          apply_multiplier(accum_data_v5, 5);
          apply_multiplier(accum_data_v6, 6);
          apply_multiplier(accum_data_v7, 7);
          apply_multiplier(accum_data_v8, 8);
          apply_multiplier(accum_data_v9, 9);
          apply_multiplier(accum_data_va, 10);
          apply_multiplier(accum_data_vb, 11);
          apply_multiplier(accum_data_vc, 12);
          apply_multiplier(accum_data_vd, 13);
          apply_multiplier(accum_data_ve, 14);
          apply_multiplier(accum_data_vf, 15);
        } else {  // not per-column, so per-row
          // Per-row (or uniform) multipliers: the same vectors of per-lane
          // shifts/multipliers apply to every column's accumulator.
          auto apply_multiplier = [=](__m512i& accum) {
            accum = _mm512_sllv_epi32(accum, left_shift);
            // Apply the fixed-point part of the multiplier.
            __m512i scaled_v_low = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
                m_64bit_low);
            __m512i scaled_v_high = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
                m_64bit_high);

            scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
            scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);

            scaled_v_low =
                _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
            scaled_v_high =
                _mm512_srav_epi64(scaled_v_high, final_right_shift_high);

            rounding_right_shift(scaled_v_low, right_shift_low);
            rounding_right_shift(scaled_v_high, right_shift_high);
            accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
            accum = _mm512_inserti32x8(accum,
                                       _mm512_cvtepi64_epi32(scaled_v_high), 1);
          };
          apply_multiplier(accum_data_v0);
          apply_multiplier(accum_data_v1);
          apply_multiplier(accum_data_v2);
          apply_multiplier(accum_data_v3);
          apply_multiplier(accum_data_v4);
          apply_multiplier(accum_data_v5);
          apply_multiplier(accum_data_v6);
          apply_multiplier(accum_data_v7);
          apply_multiplier(accum_data_v8);
          apply_multiplier(accum_data_v9);
          apply_multiplier(accum_data_va);
          apply_multiplier(accum_data_vb);
          apply_multiplier(accum_data_vc);
          apply_multiplier(accum_data_vd);
          apply_multiplier(accum_data_ve);
          apply_multiplier(accum_data_vf);
        }

        // Shift results into the destination's quantized range.
        if (params.dst_zero_point != 0) {
          __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
          accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
          accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point);
          accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point);
          accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point);
          accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point);
          accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point);
          accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point);
          accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point);
          accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point);
          accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point);
          accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point);
          accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point);
          accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point);
          accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point);
          accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point);
          accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point);
        }
      }

      const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
      const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);

      const bool store_full_block =
          (residual_rows == 16) && (residual_cols == 16);

      // Gather the named accumulators into an array so the store loops below
      // can index them by column.
      __m512i accum_data_v[16];

      // In most cases we would make this conditional on (!store_full_block) and
      // unwind the clamp-and-store loop, but the benefit appears small.
      {
        accum_data_v[0] = accum_data_v0;
        accum_data_v[1] = accum_data_v1;
        accum_data_v[2] = accum_data_v2;
        accum_data_v[3] = accum_data_v3;
        accum_data_v[4] = accum_data_v4;
        accum_data_v[5] = accum_data_v5;
        accum_data_v[6] = accum_data_v6;
        accum_data_v[7] = accum_data_v7;
        accum_data_v[8] = accum_data_v8;
        accum_data_v[9] = accum_data_v9;
        accum_data_v[10] = accum_data_va;
        accum_data_v[11] = accum_data_vb;
        accum_data_v[12] = accum_data_vc;
        accum_data_v[13] = accum_data_vd;
        accum_data_v[14] = accum_data_ve;
        accum_data_v[15] = accum_data_vf;
      }

      // Clamp and store, per destination type. Full blocks use plain stores;
      // partial blocks use row_mask-ed stores over residual_cols columns.
      if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
        std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
        const int block_col_offset = dst_stride;
        if (store_full_block) {
          for (int j = 0; j < 16; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_storeu_si128(
                reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset),
                _mm512_cvtepi32_epi8(result));
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
                                 _mm512_cvtepi32_epi8(result));
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
      } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
        std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
        const int block_col_offset = dst_stride;
        if (store_full_block) {
          // NOTE(review): this full-block loop bounds on residual_cols where
          // the int8 path above uses 16. Equivalent here because
          // store_full_block implies residual_cols == 16, but inconsistent.
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_storeu_si128(
                reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset),
                _mm512_cvtepi32_epi8(result));
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
                                 _mm512_cvtepi32_epi8(result));
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
      } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
        std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
        const int block_col_offset = dst_stride;
        if (store_full_block) {
          for (int j = 0; j < 16; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm256_storeu_si256(
                reinterpret_cast<__m256i*>(tmp_ptr + j * block_col_offset),
                _mm512_cvtepi32_epi16(result));
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask,
                                     _mm512_cvtepi32_epi16(result));
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
      } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
        // Raw int32 destination: no clamping, store accumulators directly.
        if (store_full_block) {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < 16; ++j) {
            _mm512_storeu_si512(tmp_ptr + j * dst_stride, accum_data_v[j]);
          }
        } else {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < residual_cols; ++j) {
            _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask,
                                     accum_data_v[j]);
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
      } else {
        RUY_DCHECK(false);
      }

      lhs_col_ptr += 16 * params.lhs_stride;
    }  // End row-block loop.

    dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                     16 * params.dst_stride);
    rhs_col_ptr += 16 * params.rhs_stride;
  }  // End col-block loop.
}  // NOLINT(readability/fn_size)
618 
// 8-bit GEMV (single destination column) AVX-512 kernel. Processes the one
// column 16 rows at a time: accumulates int32 dot products via 16-bit
// multiply-adds, applies the quantized multiplier (fixed-point mantissa plus
// power-of-two exponent, with gemmlowp-style rounding), then clamps and
// stores in the destination type selected by params.dst_type_id.
void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV");

  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;

  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  const std::int32_t* bias_col_ptr = params.bias;
  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
    bias_col_ptr += params.start_row;
  }

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  void* dst_ptr = dst_col_ptr;
  const std::int32_t* bias_ptr = bias_col_ptr;

  // Precompute the per-column zero-point correction
  // lhs_zero_point * rhs_sums[col]; with a single column only element [0]
  // is consumed below.
  const std::int32_t lhs_zero_point = params.lhs_zero_point;
  const bool has_rhs_sums_offsets =
      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
  std::int32_t rhs_sums_offsets[16];
  if (has_rhs_sums_offsets) {
    const __m512i rhs_sums_offset_v =
        _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
                           _mm512_loadu_si512(&params.rhs_sums[0]));
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
                        rhs_sums_offset_v);
  }

  for (int row = params.start_row; row <= params.last_row; row += 16) {
    const int residual_rows = std::min(params.dst_rows - row, 16);

    __m512i accum_data_v0;

    // Initialize with bias.
    // row_mask has one bit set per valid row; used only by the final masked
    // stores (the bias load itself is unmasked, relying on padded buffers).
    const __mmask16 row_mask =
        (static_cast<std::uint32_t>(1) << residual_rows) - 1;
    __m512i initial_accum_data =
        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(bias_ptr));
    bias_ptr += bias_ptr_block_increment;

    // Per-row zero-point correction: subtract rhs_zero_point * lhs_sums[row].
    const std::int32_t rhs_zero_point = params.rhs_zero_point;
    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
      const __m512i lhs_sums_offset =
          _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
                             _mm512_loadu_si512(&params.lhs_sums[row]));
      initial_accum_data =
          _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
    }

    // Constant term lhs_zero_point * rhs_zero_point * depth, precomputed by
    // the caller.
    const std::int32_t prod_zp_depth = params.prod_zp_depth;
    if (prod_zp_depth != 0) {
      initial_accum_data = _mm512_add_epi32(initial_accum_data,
                                            _mm512_set1_epi32(prod_zp_depth));
    }

    // Adjustments differing across columns.
    if (has_rhs_sums_offsets) {
      accum_data_v0 = _mm512_sub_epi32(initial_accum_data,
                                       _mm512_set1_epi32(rhs_sums_offsets[0]));
    } else {
      accum_data_v0 = initial_accum_data;
    }

    // Accumulation over the depth dimension, 4 levels per iteration. LHS is
    // packed 16 rows x 4 depth-levels per 64-byte load.
    const std::int8_t* lhs_ptr = lhs_col_ptr;
    const std::int8_t* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; d += 4) {
      const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
      const __m128i rhs_data_8bit =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr));

      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
      // For simplicity we load 4x the data that we need and process twice the
      // data  that we need  and store only the data we need.
      std::int32_t rhs_data[2];
      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
      // Now that we have cast the RHS data, we store it so that each value
      // can be separately loaded in the accumulation loop.
      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);

      // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
      const __m512i lhs_16_bit_low =
          _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
      // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
      const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
          _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));

      // Process column 0.
      // madd_epi16 multiplies 16-bit pairs and adds adjacent products into
      // int32 lanes, covering 2 depth levels per instruction; low + high
      // together cover all 4 levels of this iteration.
      __m512i accum_v = accum_data_v0;
      constexpr int index = 0;

      const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
      const __m512i rhs_16_bit_dup_high =
          _mm512_set1_epi32(rhs_data[index + 1]);

      accum_v = _mm512_add_epi32(
          accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
      accum_v = _mm512_add_epi32(
          accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
      accum_data_v0 = accum_v;

      lhs_ptr += 16 * 4;
      rhs_ptr += 16 * 4;
    }

    // Requantization: skipped for raw int32 output. Multiplies the int32
    // accumulators by a fixed-point multiplier and applies a (possibly
    // per-channel) exponent, with round-to-nearest and overflow saturation
    // matching the reference quantized-multiplier semantics.
    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
      __m512i m_vector;
      __m512i e_vector;
      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
      int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
      m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
          params.multiplier_fixedpoint + channel));
      e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
          params.multiplier_exponent + channel));

      // Widen the 16 int32 multipliers to 2x8 int64 lanes for the 64-bit
      // intermediate products below.
      const __m512i m_64bit_low =
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
      const __m512i m_64bit_high =
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));

      // Split the exponent into a pre-multiply left shift (e > 0) and a
      // post-multiply rounding right shift (e < 0).
      const __m512i zero_vector = _mm512_setzero_epi32();
      const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
      const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
      const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
      const __m512i final_right_shift = _mm512_set1_epi32(31);
      const __m512i right_shift_low =
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
      const __m512i right_shift_high =
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
      const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
          _mm512_extracti32x8_epi32(final_right_shift, 0));
      const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
          _mm512_extracti32x8_epi32(final_right_shift, 1));

      // A "half" added for rounding prior to truncation of 64-bit value.
      const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30);

      // Rounding arithmetic right shift on 8 int64 lanes: computes
      // (results + (1 << (exp - 1))) >> exp for exp > 0, saturating lanes
      // where adding the nudge would overflow int32 range.
      auto rounding_right_shift = [=](__m512i& results,
                                      const __m512i& exponent) {
        // Construct the "nudge" value for each lane if the exponent is
        // greater than 0. Otherwise, the nudge is 0.
        const __m512i zeros = _mm512_setzero_si512();
        const auto mask_rightshift_gtz =
            _mm512_cmpgt_epi64_mask(exponent, zeros);
        const __m512i one_shift_exp_minus1 =
            _mm512_sllv_epi64(_mm512_set1_epi64(1),
                              _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
        __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
                                              one_shift_exp_minus1);
        // Calculate the shifted sum (results + nudge) >> exp.
        const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
        const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);

        // Identify overflow in each lane and create mask.
        const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
            _mm512_set1_epi64(1),
            _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
        const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
            results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
        // Fill results with either (results + nudge) >> exponent or
        // 1 << (31 - exp) in the case of overflow.
        results = _mm512_mask_mov_epi64(
            shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
      };

      // Shift and round column 0.
      accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift);
      // Apply the fixed-point part of the multiplier.
      // mul_epi32 multiplies the low 32 bits of each 64-bit lane, producing
      // a full 64-bit product; the halves cover rows 0-7 and 8-15.
      __m512i scaled_v_low = _mm512_mul_epi32(
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)),
          m_64bit_low);
      __m512i scaled_v_high = _mm512_mul_epi32(
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)),
          m_64bit_high);

      scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
      scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);

      scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
      scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high);

      rounding_right_shift(scaled_v_low, right_shift_low);
      rounding_right_shift(scaled_v_high, right_shift_high);

      // Narrow the two 8x64-bit halves back into one 16x32-bit vector.
      accum_data_v0 =
          _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
      accum_data_v0 = _mm512_inserti32x8(
          accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1);

      if (params.dst_zero_point != 0) {
        __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
        accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
      }
    }

    // Clamp and store in the requested destination type. Stores are masked
    // by row_mask so residual blocks never write past dst_rows.
    const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
    const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);

    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
      __m512i result = accum_data_v0;
      result = _mm512_min_epi32(result, clamp_max_v);
      result = _mm512_max_epi32(result, clamp_min_v);
      _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
      __m512i result = accum_data_v0;
      result = _mm512_min_epi32(result, clamp_max_v);
      result = _mm512_max_epi32(result, clamp_min_v);
      _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
      __m512i result = accum_data_v0;
      result = _mm512_min_epi32(result, clamp_max_v);
      result = _mm512_max_epi32(result, clamp_min_v);
      _mm256_mask_storeu_epi16(tmp_ptr, row_mask,
                               _mm512_cvtepi32_epi16(result));
      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
      std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
      _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0);
      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
    } else {
      RUY_DCHECK(false);
    }

    lhs_col_ptr += 16 * params.lhs_stride;
  }  // End row-block loop.
}  // NOLINT(readability/fn_size)
853 
// Float GEMM AVX-512 kernel for a 16x16 destination tile grid. Iterates over
// 16-column by 16-row blocks; each block is computed in two column halves of
// 8 accumulator registers, with fully unrolled FMA loops on the fast path
// and masked stores for residual (partial) rows/columns.
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 float");

  // As parameters are defined, we need to scale by sizeof(float).
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  const std::int64_t dst_stride = params.dst_stride >> 2;
  const std::int64_t rhs_stride = params.rhs_stride >> 2;

  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  const int end_row = std::min(params.dst_rows, params.last_row + 16);
  const int end_col = std::min(params.dst_cols, params.last_col + 16);

  // "adj" pointers are pre-biased by start_row/start_col so that indexing
  // with absolute `row`/`col` values below lands on the right data.
  const float* adj_rhs_col_ptr =
      params.rhs_base_ptr - params.start_col * rhs_stride;
  float* adj_dst_col_ptr =
      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_ptr = params.bias;

  const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
  const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
  const bool channel_dimension_is_col =
      params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;

  // Main loop over full 16-column blocks.
  int col = params.start_col;
  for (; col <= end_col - 16; col += 16) {
    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    // Full 16-row blocks within this column block.
    int row = params.start_row;
    for (; row <= end_row - 16; row += 16) {
      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Process block in two halves, split by columns.
#pragma unroll(1)
      for (int mmm = 0; mmm < 2; ++mmm) {
        __m512 accum_data_v0;
        __m512 accum_data_v1;
        __m512 accum_data_v2;
        __m512 accum_data_v3;
        __m512 accum_data_v4;
        __m512 accum_data_v5;
        __m512 accum_data_v6;
        __m512 accum_data_v7;

        // Initialize with bias.
        if (channel_dimension_is_col) {
          // Per-column bias: broadcast one scalar per accumulator column.
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
          accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
          accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
          accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
          accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
          accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
          accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
          accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
        } else {
          // Per-row bias: one 16-row vector shared by all columns.
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);

          accum_data_v0 = initial_accum_data;
          accum_data_v1 = initial_accum_data;
          accum_data_v2 = initial_accum_data;
          accum_data_v3 = initial_accum_data;
          accum_data_v4 = initial_accum_data;
          accum_data_v5 = initial_accum_data;
          accum_data_v6 = initial_accum_data;
          accum_data_v7 = initial_accum_data;
        }

        // Depth loop: runs depth-1 iterations; the last depth level is
        // peeled below so the clamp/store can overlap the final FMAs.
        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          lhs_ptr += 16;
          rhs_ptr += 16;

          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
          // so if given an rvalue.
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
        }
        {  // nested extra blocks lead to measurable speed gains
          // Peeled final depth level, followed by clamp + unmasked stores
          // (full 16x8 half-block, no residual handling needed here).
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
          {
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
          }
        }
      }
    }    // End row-block loop.

    // The unrolling within this conditional may be somewhat pointless. It
    // depends on the kinds of models.
    if (row < end_row) {
      // Residual rows (< 16) for a full column block: same unrolled
      // computation, but stores are masked to residual_rows.
      const int residual_rows = end_row - row;

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // Process block in two halves, split by columns.
      for (int mmm = 0; mmm < 2; ++mmm) {
        __m512 accum_data_v0;
        __m512 accum_data_v1;
        __m512 accum_data_v2;
        __m512 accum_data_v3;
        __m512 accum_data_v4;
        __m512 accum_data_v5;
        __m512 accum_data_v6;
        __m512 accum_data_v7;

        // Initialize with bias.
        if (channel_dimension_is_col) {
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
          accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
          accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
          accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
          accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
          accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
          accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
          accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
        } else {
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);

          accum_data_v0 = initial_accum_data;
          accum_data_v1 = initial_accum_data;
          accum_data_v2 = initial_accum_data;
          accum_data_v3 = initial_accum_data;
          accum_data_v4 = initial_accum_data;
          accum_data_v5 = initial_accum_data;
          accum_data_v6 = initial_accum_data;
          accum_data_v7 = initial_accum_data;
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          lhs_ptr += 16;
          rhs_ptr += 16;
          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
          // so if given an rvalue.
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
        }
        {
          // Peeled final depth level, then clamp + row-masked stores.
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
          {
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask,
                                  accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask,
                                  accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask,
                                  accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask,
                                  accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask,
                                  accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask,
                                  accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask,
                                  accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask,
                                  accum_data_v7);
          }
        }
      }  // Inner half-block loop.
    }    // Residual rows, main col-block loop.
  }      // End col-block loop.

  // Residual columns (< 16): non-unrolled version using an accumulator
  // array, handling both full and partial row blocks.
  if (col < end_col) {
    RUY_DCHECK_GE(end_col - col, 0);
    RUY_DCHECK_LT(end_col - col, 16);

    __m512 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    for (int row = params.start_row; row < end_row; row += 16) {
      const int residual_rows = std::min(end_row - row, 16);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // Process block in two halves, split by columns.
      for (int mmm = 0; mmm < 2; ++mmm) {
        // Initialize with bias.
        if (channel_dimension_is_col) {
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          for (int j = 0; j < 8; ++j) {
            accum_data_v[j] = _mm512_set1_ps(bias_elem_ptr[j]);
          }
        } else {
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
          for (int j = 0; j < 8; ++j) {
            accum_data_v[j] = initial_accum_data;
          }
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < params.depth; ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;

          for (int j = 0; j < 8; ++j) {
            const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]);
            accum_data_v[j] =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
          }
          lhs_ptr += 16;
          rhs_ptr += 16;
        }

        const int residual_cols = std::min(end_col - col - 8 * mmm, 8);

        if (residual_rows == 16) {
          if (residual_cols == 8) {
            for (int j = 0; j < 8; ++j) {
              float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
              accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
              accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
              _mm512_storeu_ps(block_ptr, accum_data_v[j]);
            }
          } else {
            for (int j = 0; j < residual_cols; ++j) {
              float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
              accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
              accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
              _mm512_storeu_ps(block_ptr, accum_data_v[j]);
            }
          }
        } else {
          // Partial rows: mask the store so writes stay within dst_rows.
          for (int j = 0; j < residual_cols; ++j) {
            float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
            accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
            accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]);
          }
        }
      }  // Inner half-block loop.
    }    // End row-block loop.
  }      // Residual cols.
}
1219 
1220 void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
1221   profiler::ScopeLabel label("Kernel kAvx512 float GEMV");
1222 
1223   RUY_DCHECK_EQ(params.dst_cols, 1);
1224   RUY_DCHECK_EQ(params.last_col, 0);
1225   RUY_DCHECK_EQ(params.start_col, 0);
1226 
1227   // As parameters are defined, we need to scale by sizeof(float).
1228   const std::int64_t lhs_stride = params.lhs_stride >> 2;
1229 
1230   int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
1231   const int end_row = std::min(params.dst_rows, params.last_row + 16);
1232 
1233   float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
1234   const float* adj_lhs_col_ptr =
1235       params.lhs_base_ptr - params.start_row * lhs_stride;
1236   const float* bias_col_ptr = params.bias;
1237 
1238   const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
1239   const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
1240 
1241   __m512 accum_data_v;
1242 
1243   const float* rhs_col_ptr = params.rhs_base_ptr;
1244   float* dst_col_ptr = adj_dst_col_ptr;
1245 
1246   int row = params.start_row;
1247   for (; row <= end_row - 16; row += 16) {
1248     const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
1249     float* dst_ptr = dst_col_ptr + row;
1250     const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
1251 
1252     // Initialize with bias.
1253     accum_data_v = _mm512_loadu_ps(bias_ptr);
1254 
1255     const float* lhs_ptr = lhs_col_ptr;
1256     const float* rhs_ptr = rhs_col_ptr;
1257     for (int d = 0; d < params.depth; ++d) {
1258       const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
1259       const float rhs_data = *rhs_ptr;
1260 
1261       const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data);
1262       accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
1263       lhs_ptr += 16;
1264       rhs_ptr += 16;
1265     }
1266 
1267     accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
1268     accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
1269     _mm512_storeu_ps(dst_ptr, accum_data_v);
1270   }  // End row-block loop.
1271 
1272   if (row < end_row) {
1273     const int residual_rows = end_row - row;
1274     RUY_CHECK_GE(residual_rows, 1);
1275     RUY_CHECK_LT(residual_rows, 16);
1276 
1277     const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
1278     float* dst_ptr = dst_col_ptr + row;
1279     const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
1280 
1281     // Initialize with bias.
1282     const __mmask16 row_mask =
1283         (static_cast<std::uint32_t>(1) << residual_rows) - 1;
1284     accum_data_v = _mm512_loadu_ps(bias_ptr);
1285 
1286     const float* lhs_ptr = lhs_col_ptr;
1287     const float* rhs_ptr = rhs_col_ptr;
1288     for (int d = 0; d < params.depth; ++d) {
1289       const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
1290       const float rhs_data = *rhs_ptr;
1291 
1292       const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data);
1293       accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
1294       lhs_ptr += 16;
1295       rhs_ptr += 16;
1296     }
1297 
1298     accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
1299     accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
1300     _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v);
1301   }  // End handling of residual rows.
1302 }
1303 
1304 #endif  //  RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
1305 
1306 }  // namespace ruy
1307