/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>

#include "ruy/check_macros.h"
#include "ruy/kernel_x86.h"
#include "ruy/opt_set.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM))

void Kernel8bitAvx512(const KernelParams8bit<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvx512(const KernelParamsFloat<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)

void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 8-bit");

  std::int32_t dst_stride = 0;
  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
    dst_stride = params.dst_stride;
  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int16_t);
  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int32_t);
  } else {
    RUY_DCHECK(false);
  }

  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;

  for (int col = params.start_col; col <= params.last_col; col += 16) {
    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
    void* dst_ptr = dst_col_ptr;

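    // Note on the zero-point adjustments below: quantized matmul accumulates
    // (lhs - lhs_zp) * (rhs - rhs_zp) over the depth. Expanding the product
    // gives
    //   sum(lhs*rhs) - lhs_zp*sum(rhs) - rhs_zp*sum(lhs)
    //       + depth*lhs_zp*rhs_zp.
    // The last term is precomputed as params.prod_zp_depth; the two middle
    // terms are applied from the precomputed rhs_sums/lhs_sums.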
    const std::int32_t lhs_zero_point = params.lhs_zero_point;
    const bool has_rhs_sums_offsets =
        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
    std::int32_t rhs_sums_offsets[16];
    if (has_rhs_sums_offsets) {
      const __m512i rhs_sums_offset_v =
          _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
                             _mm512_loadu_si512(&params.rhs_sums[col]));
      _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
                          rhs_sums_offset_v);
    }

    for (int row = params.start_row; row <= params.last_row; row += 16) {
      int channel =
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
      int multiplier_channel =
          (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;

      const int residual_rows = std::min(params.dst_rows - row, 16);
      const int residual_cols = std::min(params.dst_cols - col, 16);

      __m512i accum_data_v0;
      __m512i accum_data_v1;
      __m512i accum_data_v2;
      __m512i accum_data_v3;
      __m512i accum_data_v4;
      __m512i accum_data_v5;
      __m512i accum_data_v6;
      __m512i accum_data_v7;
      __m512i accum_data_v8;
      __m512i accum_data_v9;
      __m512i accum_data_va;
      __m512i accum_data_vb;
      __m512i accum_data_vc;
      __m512i accum_data_vd;
      __m512i accum_data_ve;
      __m512i accum_data_vf;

      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;
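      // E.g. residual_rows == 5 yields row_mask == 0b11111, selecting the
      // five rows that actually exist in the destination.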

      // initial_accum_data will be used to initialize each of the
      // accum_data_* accumulator registers. We compute into it the terms
      // that are identical across columns.
      __m512i initial_accum_data = _mm512_set1_epi32(params.prod_zp_depth);

      // In the channels-are-rows case, we can load bias here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        initial_accum_data = _mm512_add_epi32(
            initial_accum_data,
            _mm512_loadu_si512(
                reinterpret_cast<const __m512i*>(params.bias + row)));
      }

      const std::int32_t rhs_zero_point = params.rhs_zero_point;
      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
        const __m512i lhs_sums_offset =
            _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
                               _mm512_loadu_si512(&params.lhs_sums[row]));
        initial_accum_data =
            _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
      }

      // Adjustments differing across columns.
      if (has_rhs_sums_offsets) {
        accum_data_v0 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0]));
        accum_data_v1 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1]));
        accum_data_v2 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2]));
        accum_data_v3 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3]));
        accum_data_v4 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4]));
        accum_data_v5 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5]));
        accum_data_v6 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6]));
        accum_data_v7 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7]));
        accum_data_v8 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8]));
        accum_data_v9 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9]));
        accum_data_va = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10]));
        accum_data_vb = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11]));
        accum_data_vc = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12]));
        accum_data_vd = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13]));
        accum_data_ve = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14]));
        accum_data_vf = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15]));
      } else {
        accum_data_v0 = initial_accum_data;
        accum_data_v1 = initial_accum_data;
        accum_data_v2 = initial_accum_data;
        accum_data_v3 = initial_accum_data;
        accum_data_v4 = initial_accum_data;
        accum_data_v5 = initial_accum_data;
        accum_data_v6 = initial_accum_data;
        accum_data_v7 = initial_accum_data;
        accum_data_v8 = initial_accum_data;
        accum_data_v9 = initial_accum_data;
        accum_data_va = initial_accum_data;
        accum_data_vb = initial_accum_data;
        accum_data_vc = initial_accum_data;
        accum_data_vd = initial_accum_data;
        accum_data_ve = initial_accum_data;
        accum_data_vf = initial_accum_data;
      }

      // Finally, in the channels-are-columns case, load bias data here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        const __m512i bias_data = _mm512_loadu_si512(
            reinterpret_cast<const __m512i*>(params.bias + col));
        accum_data_v0 = _mm512_add_epi32(
            accum_data_v0,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(0), bias_data));
        accum_data_v1 = _mm512_add_epi32(
            accum_data_v1,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(1), bias_data));
        accum_data_v2 = _mm512_add_epi32(
            accum_data_v2,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(2), bias_data));
        accum_data_v3 = _mm512_add_epi32(
            accum_data_v3,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(3), bias_data));
        accum_data_v4 = _mm512_add_epi32(
            accum_data_v4,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(4), bias_data));
        accum_data_v5 = _mm512_add_epi32(
            accum_data_v5,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(5), bias_data));
        accum_data_v6 = _mm512_add_epi32(
            accum_data_v6,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(6), bias_data));
        accum_data_v7 = _mm512_add_epi32(
            accum_data_v7,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(7), bias_data));
        accum_data_v8 = _mm512_add_epi32(
            accum_data_v8,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(8), bias_data));
        accum_data_v9 = _mm512_add_epi32(
            accum_data_v9,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(9), bias_data));
        accum_data_va = _mm512_add_epi32(
            accum_data_va,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(10), bias_data));
        accum_data_vb = _mm512_add_epi32(
            accum_data_vb,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(11), bias_data));
        accum_data_vc = _mm512_add_epi32(
            accum_data_vc,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(12), bias_data));
        accum_data_vd = _mm512_add_epi32(
            accum_data_vd,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(13), bias_data));
        accum_data_ve = _mm512_add_epi32(
            accum_data_ve,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(14), bias_data));
        accum_data_vf = _mm512_add_epi32(
            accum_data_vf,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(15), bias_data));
      }

      const std::int8_t* lhs_ptr = lhs_col_ptr;
      const std::int8_t* rhs_ptr = rhs_col_ptr;
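      // Each iteration of the depth loop consumes 4 levels of depth for the
      // whole 16x16 block, so the packed LHS and RHS pointers both advance
      // by 16 * 4 bytes per step.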
      for (int d = 0; d < params.depth; d += 4) {
        const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
        __m512i rhs_data_8bit = _mm512_loadu_si512(rhs_ptr);

        // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
        std::int32_t rhs_data[32];
        const __m256i rhs_data_bottom_lane =
            _mm512_castsi512_si256(rhs_data_8bit);
        const __m256i rhs_data_top_lane =
            _mm512_extracti32x8_epi32(rhs_data_8bit, 1);
        const __m512i rhs_16_bit_dup_low =
            _mm512_cvtepi8_epi16(rhs_data_bottom_lane);
        const __m512i rhs_16_bit_dup_high =
            _mm512_cvtepi8_epi16(rhs_data_top_lane);
        // Now that we have cast the RHS data, we store it so that each value
        // can be separately loaded in the accumulation loop.
        _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_data),
                            rhs_16_bit_dup_low);
        _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_data + 16),
                            rhs_16_bit_dup_high);

        // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
        const __m512i lhs_16_bit_low =
            _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
        // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
        const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
            _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));

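        // Each 32-bit lane of lhs_16_bit_low/high holds a pair of 16-bit LHS
        // values for consecutive depth levels of one row; the matching RHS
        // pair is broadcast below, so each _mm512_madd_epi16 yields a
        // per-lane dot product over two depth levels.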
        auto process_column = [=](int col, __m512i& accum) {
          const __m512i rhs_16_bit_dup_low =
              _mm512_set1_epi32(rhs_data[2 * col]);
          const __m512i rhs_16_bit_dup_high =
              _mm512_set1_epi32(rhs_data[2 * col + 1]);

          accum = _mm512_add_epi32(
              accum, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
          accum = _mm512_add_epi32(
              accum, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
        };
        process_column(0, accum_data_v0);
        process_column(1, accum_data_v1);
        process_column(2, accum_data_v2);
        process_column(3, accum_data_v3);
        process_column(4, accum_data_v4);
        process_column(5, accum_data_v5);
        process_column(6, accum_data_v6);
        process_column(7, accum_data_v7);
        process_column(8, accum_data_v8);
        process_column(9, accum_data_v9);
        process_column(10, accum_data_va);
        process_column(11, accum_data_vb);
        process_column(12, accum_data_vc);
        process_column(13, accum_data_vd);
        process_column(14, accum_data_ve);
        process_column(15, accum_data_vf);

        lhs_ptr += 16 * 4;
        rhs_ptr += 16 * 4;
      }

      if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
        // The non-per-channel case could equivalently be handled in the
        // per-row or per-column code path. The per-row code path is slightly
        // more efficient so we handle it there.
        const bool per_column_multiplier =
            (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
            (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);

        __m512i m_vector;
        __m512i e_vector;
        // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
        m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
            params.multiplier_fixedpoint + multiplier_channel));
        e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
            params.multiplier_exponent + multiplier_channel));

        const __m512i m_64bit_low =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
        const __m512i m_64bit_high =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));

        const __m512i zero_vector = _mm512_setzero_epi32();
        const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
        const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
        const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
        const __m512i final_right_shift = _mm512_set1_epi32(31);
        const __m512i right_shift_low =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
        const __m512i right_shift_high =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
        const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
            _mm512_extracti32x8_epi32(final_right_shift, 0));
        const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
            _mm512_extracti32x8_epi32(final_right_shift, 1));

        // A "half" added for rounding prior to truncation of 64-bit value.
        const __m512i offset_vector =
            _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
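        // Taken together, the steps below compute, per 32-bit accumulator
        // lane,
        //   acc = RoundingRightShift((acc << left_shift) * m >> 31,
        //                            right_shift)
        // i.e. gemmlowp-style fixed-point multiplication: widen to 64-bit,
        // multiply by the fixed-point multiplier m, add 2^30 so the >> 31
        // rounds, then apply the remaining rounding right shift.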

        auto rounding_right_shift = [=](__m512i& results,
                                        const __m512i& exponent) {
          // Construct the "nudge" value for each lane if the exponent is
          // greater than 0. Otherwise, the nudge is 0.
          const __m512i zeros = _mm512_setzero_si512();
          const auto mask_rightshift_gtz =
              _mm512_cmpgt_epi64_mask(exponent, zeros);
          const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(
              _mm512_set1_epi64(1),
              _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
          __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
                                                one_shift_exp_minus1);
          // Calculate the shifted sum (results + nudge) >> exp.
          const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
          const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);

          // Identify overflow in each lane and create mask.
          const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
              _mm512_set1_epi64(1),
              _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
          const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
              results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
          // Fill results with either (results + nudge) >> exponent or
          // 1 << (31 - exp) in the case of overflow.
          results = _mm512_mask_mov_epi64(
              shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
        };
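        // Worked example: results == 5 and exponent == 1 give nudge == 1 and
        // (5 + 1) >> 1 == 3, i.e. 5/2 == 2.5 rounds half away from zero.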

        if (per_column_multiplier) {
          auto apply_multiplier = [=](__m512i& accum, int col) {
            __m512i perm_64bit_vals = _mm512_set1_epi64(col % 8);
            // Apply the fixed-point part of the multiplier.
            __m512i left_shift_val =
                _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift);
            __m512i m_64bit_val = _mm512_permutexvar_epi64(
                perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
            __m512i offset_vector_val =
                _mm512_permutexvar_epi64(perm_64bit_vals, offset_vector);
            __m512i final_right_shift_val = _mm512_permutexvar_epi64(
                perm_64bit_vals,
                col < 8 ? final_right_shift_low : final_right_shift_high);
            __m512i right_shift_val = _mm512_permutexvar_epi64(
                perm_64bit_vals, col < 8 ? right_shift_low : right_shift_high);

            accum = _mm512_sllv_epi32(accum, left_shift_val);
            __m512i scaled_v_low = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
                m_64bit_val);
            __m512i scaled_v_high = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
                m_64bit_val);

            scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_val);
            scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_val);

            scaled_v_low =
                _mm512_srav_epi64(scaled_v_low, final_right_shift_val);
            scaled_v_high =
                _mm512_srav_epi64(scaled_v_high, final_right_shift_val);

            rounding_right_shift(scaled_v_low, right_shift_val);
            rounding_right_shift(scaled_v_high, right_shift_val);

            accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
            accum = _mm512_inserti32x8(accum,
                                       _mm512_cvtepi64_epi32(scaled_v_high), 1);
          };
          apply_multiplier(accum_data_v0, 0);
          apply_multiplier(accum_data_v1, 1);
          apply_multiplier(accum_data_v2, 2);
          apply_multiplier(accum_data_v3, 3);
          apply_multiplier(accum_data_v4, 4);
          apply_multiplier(accum_data_v5, 5);
          apply_multiplier(accum_data_v6, 6);
          apply_multiplier(accum_data_v7, 7);
          apply_multiplier(accum_data_v8, 8);
          apply_multiplier(accum_data_v9, 9);
          apply_multiplier(accum_data_va, 10);
          apply_multiplier(accum_data_vb, 11);
          apply_multiplier(accum_data_vc, 12);
          apply_multiplier(accum_data_vd, 13);
          apply_multiplier(accum_data_ve, 14);
          apply_multiplier(accum_data_vf, 15);
        } else {  // not per-column, so per-row
          auto apply_multiplier = [=](__m512i& accum) {
            accum = _mm512_sllv_epi32(accum, left_shift);
            // Apply the fixed-point part of the multiplier.
            __m512i scaled_v_low = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
                m_64bit_low);
            __m512i scaled_v_high = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
                m_64bit_high);

            scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
            scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);

            scaled_v_low =
                _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
            scaled_v_high =
                _mm512_srav_epi64(scaled_v_high, final_right_shift_high);

            rounding_right_shift(scaled_v_low, right_shift_low);
            rounding_right_shift(scaled_v_high, right_shift_high);
            accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
            accum = _mm512_inserti32x8(accum,
                                       _mm512_cvtepi64_epi32(scaled_v_high), 1);
          };
          apply_multiplier(accum_data_v0);
          apply_multiplier(accum_data_v1);
          apply_multiplier(accum_data_v2);
          apply_multiplier(accum_data_v3);
          apply_multiplier(accum_data_v4);
          apply_multiplier(accum_data_v5);
          apply_multiplier(accum_data_v6);
          apply_multiplier(accum_data_v7);
          apply_multiplier(accum_data_v8);
          apply_multiplier(accum_data_v9);
          apply_multiplier(accum_data_va);
          apply_multiplier(accum_data_vb);
          apply_multiplier(accum_data_vc);
          apply_multiplier(accum_data_vd);
          apply_multiplier(accum_data_ve);
          apply_multiplier(accum_data_vf);
        }

        if (params.dst_zero_point != 0) {
          __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
          accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
          accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point);
          accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point);
          accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point);
          accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point);
          accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point);
          accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point);
          accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point);
          accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point);
          accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point);
          accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point);
          accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point);
          accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point);
          accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point);
          accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point);
          accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point);
        }
      }

      const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
      const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);

      const bool store_full_block =
          (residual_rows == 16) && (residual_cols == 16);

      __m512i accum_data_v[16];

      // In most cases we would make this conditional on (!store_full_block)
      // and unroll the clamp-and-store loop, but the benefit appears small.
      {
        accum_data_v[0] = accum_data_v0;
        accum_data_v[1] = accum_data_v1;
        accum_data_v[2] = accum_data_v2;
        accum_data_v[3] = accum_data_v3;
        accum_data_v[4] = accum_data_v4;
        accum_data_v[5] = accum_data_v5;
        accum_data_v[6] = accum_data_v6;
        accum_data_v[7] = accum_data_v7;
        accum_data_v[8] = accum_data_v8;
        accum_data_v[9] = accum_data_v9;
        accum_data_v[10] = accum_data_va;
        accum_data_v[11] = accum_data_vb;
        accum_data_v[12] = accum_data_vc;
        accum_data_v[13] = accum_data_vd;
        accum_data_v[14] = accum_data_ve;
        accum_data_v[15] = accum_data_vf;
      }

      if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
        std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
        const int block_col_offset = dst_stride;
        if (store_full_block) {
          for (int j = 0; j < 16; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_storeu_si128(
                reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset),
                _mm512_cvtepi32_epi8(result));
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
                                 _mm512_cvtepi32_epi8(result));
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
      } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
        std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
        const int block_col_offset = dst_stride;
        if (store_full_block) {
          for (int j = 0; j < 16; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_storeu_si128(
                reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset),
                _mm512_cvtepi32_epi8(result));
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
                                 _mm512_cvtepi32_epi8(result));
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
      } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
        std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
        const int block_col_offset = dst_stride;
        if (store_full_block) {
          for (int j = 0; j < 16; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm256_storeu_si256(
                reinterpret_cast<__m256i*>(tmp_ptr + j * block_col_offset),
                _mm512_cvtepi32_epi16(result));
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask,
                                     _mm512_cvtepi32_epi16(result));
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
      } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
        if (store_full_block) {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < 16; ++j) {
            _mm512_storeu_si512(tmp_ptr + j * dst_stride, accum_data_v[j]);
          }
        } else {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < residual_cols; ++j) {
            _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask,
                                     accum_data_v[j]);
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
      } else {
        RUY_DCHECK(false);
      }

      lhs_col_ptr += 16 * params.lhs_stride;
    }  // End row-block loop.

    dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                     16 * params.dst_stride);
    rhs_col_ptr += 16 * params.rhs_stride;
  }  // End col-block loop.
}  // NOLINT(readability/fn_size)

void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV");

  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;

  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  const std::int32_t* bias_col_ptr = params.bias;
  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
    bias_col_ptr += params.start_row;
  }

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  void* dst_ptr = dst_col_ptr;
  const std::int32_t* bias_ptr = bias_col_ptr;

  const std::int32_t lhs_zero_point = params.lhs_zero_point;
  const bool has_rhs_sums_offsets =
      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
  std::int32_t rhs_sums_offsets[16];
  if (has_rhs_sums_offsets) {
    const __m512i rhs_sums_offset_v =
        _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
                           _mm512_loadu_si512(&params.rhs_sums[0]));
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
                        rhs_sums_offset_v);
  }

  for (int row = params.start_row; row <= params.last_row; row += 16) {
    const int residual_rows = std::min(params.dst_rows - row, 16);

    __m512i accum_data_v0;

    // Initialize with bias.
    const __mmask16 row_mask =
        (static_cast<std::uint32_t>(1) << residual_rows) - 1;
    __m512i initial_accum_data =
        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(bias_ptr));
    bias_ptr += bias_ptr_block_increment;

    const std::int32_t rhs_zero_point = params.rhs_zero_point;
    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
      const __m512i lhs_sums_offset =
          _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
                             _mm512_loadu_si512(&params.lhs_sums[row]));
      initial_accum_data =
          _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
    }

    const std::int32_t prod_zp_depth = params.prod_zp_depth;
    if (prod_zp_depth != 0) {
      initial_accum_data = _mm512_add_epi32(initial_accum_data,
                                            _mm512_set1_epi32(prod_zp_depth));
    }

    // Adjustments differing across columns.
    if (has_rhs_sums_offsets) {
      accum_data_v0 = _mm512_sub_epi32(initial_accum_data,
                                       _mm512_set1_epi32(rhs_sums_offsets[0]));
    } else {
      accum_data_v0 = initial_accum_data;
    }

    const std::int8_t* lhs_ptr = lhs_col_ptr;
    const std::int8_t* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; d += 4) {
      const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
      const __m128i rhs_data_8bit =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr));

      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
      // For simplicity we load 4x the data we need, process twice the data
      // we need, and store only the data we need.
      std::int32_t rhs_data[2];
      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
      // Now that we have cast the RHS data, we store it so that each value
      // can be separately loaded in the accumulation loop.
      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);

      // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
      const __m512i lhs_16_bit_low =
          _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
      // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
      const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
          _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));

      // Process column 0.
      __m512i accum_v = accum_data_v0;
      constexpr int index = 0;

      const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
      const __m512i rhs_16_bit_dup_high =
          _mm512_set1_epi32(rhs_data[index + 1]);

      accum_v = _mm512_add_epi32(
          accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
      accum_v = _mm512_add_epi32(
          accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
      accum_data_v0 = accum_v;

      lhs_ptr += 16 * 4;
      rhs_ptr += 16 * 4;
    }

    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
      __m512i m_vector;
      __m512i e_vector;
      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
      int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
      m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
          params.multiplier_fixedpoint + channel));
      e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
          params.multiplier_exponent + channel));

      const __m512i m_64bit_low =
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
      const __m512i m_64bit_high =
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));

      const __m512i zero_vector = _mm512_setzero_epi32();
      const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
      const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
      const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
      const __m512i final_right_shift = _mm512_set1_epi32(31);
      const __m512i right_shift_low =
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
      const __m512i right_shift_high =
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
      const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
          _mm512_extracti32x8_epi32(final_right_shift, 0));
      const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
          _mm512_extracti32x8_epi32(final_right_shift, 1));

      // A "half" added for rounding prior to truncation of 64-bit value.
      const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30);

      auto rounding_right_shift = [=](__m512i& results,
                                      const __m512i& exponent) {
        // Construct the "nudge" value for each lane if the exponent is
        // greater than 0. Otherwise, the nudge is 0.
        const __m512i zeros = _mm512_setzero_si512();
        const auto mask_rightshift_gtz =
            _mm512_cmpgt_epi64_mask(exponent, zeros);
        const __m512i one_shift_exp_minus1 =
            _mm512_sllv_epi64(_mm512_set1_epi64(1),
                              _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
        __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
                                              one_shift_exp_minus1);
        // Calculate the shifted sum (results + nudge) >> exp.
        const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
        const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);

        // Identify overflow in each lane and create mask.
        const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
            _mm512_set1_epi64(1),
            _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
        const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
            results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
        // Fill results with either (results + nudge) >> exponent or
        // 1 << (31 - exp) in the case of overflow.
        results = _mm512_mask_mov_epi64(
            shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
      };

      // Shift and round column 0.
      accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift);
      // Apply the fixed-point part of the multiplier.
      __m512i scaled_v_low = _mm512_mul_epi32(
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)),
          m_64bit_low);
      __m512i scaled_v_high = _mm512_mul_epi32(
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)),
          m_64bit_high);

      scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
      scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);

      scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
      scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high);

      rounding_right_shift(scaled_v_low, right_shift_low);
      rounding_right_shift(scaled_v_high, right_shift_high);

      accum_data_v0 =
          _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
      accum_data_v0 = _mm512_inserti32x8(
          accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1);

      if (params.dst_zero_point != 0) {
        __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
        accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
      }
    }

    const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
    const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);

    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
      __m512i result = accum_data_v0;
      result = _mm512_min_epi32(result, clamp_max_v);
      result = _mm512_max_epi32(result, clamp_min_v);
      _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
      __m512i result = accum_data_v0;
      result = _mm512_min_epi32(result, clamp_max_v);
      result = _mm512_max_epi32(result, clamp_min_v);
      _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
      __m512i result = accum_data_v0;
      result = _mm512_min_epi32(result, clamp_max_v);
      result = _mm512_max_epi32(result, clamp_min_v);
      _mm256_mask_storeu_epi16(tmp_ptr, row_mask,
                               _mm512_cvtepi32_epi16(result));
      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
      std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
      _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0);
      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
    } else {
      RUY_DCHECK(false);
    }

    lhs_col_ptr += 16 * params.lhs_stride;
  }  // End row-block loop.
}  // NOLINT(readability/fn_size)

void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 float");

  // The strides in params are given in bytes, so scale them by sizeof(float).
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  const std::int64_t dst_stride = params.dst_stride >> 2;
  const std::int64_t rhs_stride = params.rhs_stride >> 2;

  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  const int end_row = std::min(params.dst_rows, params.last_row + 16);
  const int end_col = std::min(params.dst_cols, params.last_col + 16);

  const float* adj_rhs_col_ptr =
      params.rhs_base_ptr - params.start_col * rhs_stride;
  float* adj_dst_col_ptr =
      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_ptr = params.bias;

  const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
  const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
  const bool channel_dimension_is_col =
      params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;

  int col = params.start_col;
  for (; col <= end_col - 16; col += 16) {
    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    int row = params.start_row;
    for (; row <= end_row - 16; row += 16) {
      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Process block in two halves, split by columns.
#pragma unroll(1)
      for (int mmm = 0; mmm < 2; ++mmm) {
        __m512 accum_data_v0;
        __m512 accum_data_v1;
        __m512 accum_data_v2;
        __m512 accum_data_v3;
        __m512 accum_data_v4;
        __m512 accum_data_v5;
        __m512 accum_data_v6;
        __m512 accum_data_v7;

        // Initialize with bias.
        if (channel_dimension_is_col) {
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
          accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
          accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
          accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
          accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
          accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
          accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
          accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
        } else {
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);

          accum_data_v0 = initial_accum_data;
          accum_data_v1 = initial_accum_data;
          accum_data_v2 = initial_accum_data;
          accum_data_v3 = initial_accum_data;
          accum_data_v4 = initial_accum_data;
          accum_data_v5 = initial_accum_data;
          accum_data_v6 = initial_accum_data;
          accum_data_v7 = initial_accum_data;
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
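        // The packed RHS stores 16 column values per depth step; this half
        // of the block reads columns 8 * mmm .. 8 * mmm + 7 of it.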
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          lhs_ptr += 16;
          rhs_ptr += 16;

          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
          // so if given an rvalue.
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
        }
        {  // nested extra blocks lead to measurable speed gains
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
          {
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
          }
        }
      }
    }  // End row-block loop.

    // Whether the unrolling within this conditional pays off depends on the
    // shapes of the models being run.
    if (row < end_row) {
      const int residual_rows = end_row - row;

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // Process block in two halves, split by columns.
      for (int mmm = 0; mmm < 2; ++mmm) {
        __m512 accum_data_v0;
        __m512 accum_data_v1;
        __m512 accum_data_v2;
        __m512 accum_data_v3;
        __m512 accum_data_v4;
        __m512 accum_data_v5;
        __m512 accum_data_v6;
        __m512 accum_data_v7;

        // Initialize with bias.
        if (channel_dimension_is_col) {
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
          accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
          accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
          accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
          accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
          accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
          accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
          accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
        } else {
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);

          accum_data_v0 = initial_accum_data;
          accum_data_v1 = initial_accum_data;
          accum_data_v2 = initial_accum_data;
          accum_data_v3 = initial_accum_data;
          accum_data_v4 = initial_accum_data;
          accum_data_v5 = initial_accum_data;
          accum_data_v6 = initial_accum_data;
          accum_data_v7 = initial_accum_data;
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          lhs_ptr += 16;
          rhs_ptr += 16;
          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
          // so if given an rvalue.
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
        }
        {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
          {
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask,
                                  accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask,
                                  accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask,
                                  accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask,
                                  accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask,
                                  accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask,
                                  accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask,
                                  accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask,
                                  accum_data_v7);
          }
        }
      }  // Inner half-block loop.
    }  // Residual rows, main col-block loop.
  }  // End col-block loop.

  if (col < end_col) {
    RUY_DCHECK_GE(end_col - col, 0);
    RUY_DCHECK_LT(end_col - col, 16);

    __m512 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    for (int row = params.start_row; row < end_row; row += 16) {
      const int residual_rows = std::min(end_row - row, 16);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // Process block in two halves, split by columns.
      for (int mmm = 0; mmm < 2; ++mmm) {
        // Initialize with bias.
        if (channel_dimension_is_col) {
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          for (int j = 0; j < 8; ++j) {
            accum_data_v[j] = _mm512_set1_ps(bias_elem_ptr[j]);
          }
        } else {
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
          for (int j = 0; j < 8; ++j) {
            accum_data_v[j] = initial_accum_data;
          }
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < params.depth; ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;

          for (int j = 0; j < 8; ++j) {
            const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]);
            accum_data_v[j] =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
          }
          lhs_ptr += 16;
          rhs_ptr += 16;
        }

        const int residual_cols = std::min(end_col - col - 8 * mmm, 8);

        if (residual_rows == 16) {
          if (residual_cols == 8) {
            for (int j = 0; j < 8; ++j) {
              float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
              accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
              accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
              _mm512_storeu_ps(block_ptr, accum_data_v[j]);
            }
          } else {
            for (int j = 0; j < residual_cols; ++j) {
              float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
              accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
              accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
              _mm512_storeu_ps(block_ptr, accum_data_v[j]);
            }
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
            accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
            accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]);
          }
        }
      }  // Inner half-block loop.
    }  // End row-block loop.
  }  // Residual cols.
}

void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 float GEMV");

  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  // The strides in params are given in bytes, so scale them by sizeof(float).
  const std::int64_t lhs_stride = params.lhs_stride >> 2;

  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  const int end_row = std::min(params.dst_rows, params.last_row + 16);

  float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_col_ptr = params.bias;

  const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
  const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);

  __m512 accum_data_v;

  const float* rhs_col_ptr = params.rhs_base_ptr;
  float* dst_col_ptr = adj_dst_col_ptr;

  int row = params.start_row;
  for (; row <= end_row - 16; row += 16) {
    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
    float* dst_ptr = dst_col_ptr + row;
    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

    // Initialize with bias.
    accum_data_v = _mm512_loadu_ps(bias_ptr);

    const float* lhs_ptr = lhs_col_ptr;
    const float* rhs_ptr = rhs_col_ptr;
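    // Only column 0 exists in the GEMV case, but the packed RHS still stores
    // 16 values per depth step, so rhs_ptr advances by 16 while only
    // rhs_ptr[0] is read.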
    for (int d = 0; d < params.depth; ++d) {
      const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
      const float rhs_data = *rhs_ptr;

      const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data);
      accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
      lhs_ptr += 16;
      rhs_ptr += 16;
    }

    accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
    accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
    _mm512_storeu_ps(dst_ptr, accum_data_v);
  }  // End row-block loop.

  if (row < end_row) {
    const int residual_rows = end_row - row;
    RUY_CHECK_GE(residual_rows, 1);
    RUY_CHECK_LT(residual_rows, 16);

    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
    float* dst_ptr = dst_col_ptr + row;
    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

    // Initialize with bias.
    const __mmask16 row_mask =
        (static_cast<std::uint32_t>(1) << residual_rows) - 1;
    accum_data_v = _mm512_loadu_ps(bias_ptr);

    const float* lhs_ptr = lhs_col_ptr;
    const float* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; ++d) {
      const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
      const float rhs_data = *rhs_ptr;

      const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data);
      accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
      lhs_ptr += 16;
      rhs_ptr += 16;
    }

    accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
    accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
    _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v);
  }  // End handling of residual rows.
}

#endif  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)

}  // namespace ruy