/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/kernel_common.h"
#include "ruy/kernel_x86.h"
#include "ruy/opt_set.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))

void Kernel8bitAvx2(const KernelParams8bit<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvx2(const KernelParamsFloat<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)

static constexpr int kAvx8bitBlockSize = 8;
static constexpr int kAvx8bitInnerSize = 4;
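// Blocking overview (informal note): each iteration of the row/column loops
// below produces an 8x8 block of the destination (kAvx8bitBlockSize), and
// each pass of the inner depth loop consumes kAvx8bitInnerSize = 4 levels of
// depth, which is the amount that one low/high pair of _mm256_madd_epi16
// operations can reduce at once.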

namespace {
namespace intrin_utils {

template <>
inline __m256i mm256_shuffle_epi8<Path::kAvx2Fma>(const __m256i& a,
                                                  const __m256i& b) {
  return _mm256_shuffle_epi8(a, b);
}

// Make an inline function for FMA so we can share the float kernels
// with non-FMA code.
template <>
inline __m256 MulAdd<Path::kAvx2Fma>(const __m256& a, const __m256& b,
                                     const __m256& c) {
  return _mm256_fmadd_ps(a, b, c);
}

template <>
inline __m128i mm256_extracti128_si256<Path::kAvx2Fma>(const __m256i& a,
                                                       const int imm) {
  switch (imm) {
    case 0:
      return _mm256_extracti128_si256(a, 0);
    case 1:
      return _mm256_extracti128_si256(a, 1);
    default:
      RUY_DCHECK_LT(imm, 2);
      return _mm_setzero_si128();
  }
}

__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
                           const __m256i& mask) {
  __m256 result =
      _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
                       _mm256_castsi256_ps(mask));
  return _mm256_castps_si256(result);
}

}  // namespace intrin_utils
}  // namespace

template <Path path>
void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit");
  const std::int8_t splitter_idx_data[32] = {
      0, 1, 4, 5, 8, 9, 12, 13,    //
      2, 3, 6, 7, 10, 11, 14, 15,  //
      0, 1, 4, 5, 8, 9, 12, 13,    //
      2, 3, 6, 7, 10, 11, 14, 15   //
  };
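  // What splitter_idx_data does (informal note): within each 128-bit lane,
  // the _mm256_shuffle_epi8 in the depth loop gathers bytes 0, 1, 4, 5, 8, 9,
  // 12, 13 into the lower half and bytes 2, 3, 6, 7, 10, 11, 14, 15 into the
  // upper half. After sign-extension to 16 bits this yields the
  // lhs_16_bit_low / lhs_16_bit_high registers that are fed to
  // _mm256_madd_epi16 against the matching pairs of RHS values.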

  std::int32_t dst_stride = 0;
  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
    dst_stride = params.dst_stride;
  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int16_t);
  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int32_t);
  } else {
    RUY_DCHECK(false);
  }

  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;

  for (int col = params.start_col; col <= params.last_col;
       col += kAvx8bitBlockSize) {
    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
    void* dst_ptr = dst_col_ptr;

    const std::int32_t lhs_zero_point = params.lhs_zero_point;
    const bool has_rhs_sums_offsets =
        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
    std::int32_t rhs_sums_offsets[8];
    if (has_rhs_sums_offsets) {
      const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
          _mm256_set1_epi32(lhs_zero_point),
          _mm256_loadu_si256(
              reinterpret_cast<__m256i const*>(&params.rhs_sums[col])));
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
                          rhs_sums_offset_v);
    }
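    // Zero-point handling, in outline: expanding the quantized product
    //   sum_d (lhs[r][d] - lhs_zp) * (rhs[d][c] - rhs_zp)
    // gives the raw int8 dot product minus lhs_zp * sum_d rhs[d][c], minus
    // rhs_zp * sum_d lhs[r][d], plus a constant term over the depth
    // (params.prod_zp_depth). The block above precomputes the per-column term
    // (lhs_zero_point times the packed RHS sums); the per-row term and the
    // prod_zp_depth term are folded into initial_accum_data further down.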

    for (int row = params.start_row; row <= params.last_row;
         row += kAvx8bitBlockSize) {
      int channel =
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
      int multiplier_channel =
          (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
      const int residual_rows =
          std::min(params.dst_rows - row, kAvx8bitBlockSize);
      const int residual_cols =
          std::min(params.dst_cols - col, kAvx8bitBlockSize);

      const __m256i splitter_idx = _mm256_loadu_si256(
          reinterpret_cast<__m256i const*>(splitter_idx_data));

      __m256i accum_data_v0;
      __m256i accum_data_v1;
      __m256i accum_data_v2;
      __m256i accum_data_v3;
      __m256i accum_data_v4;
      __m256i accum_data_v5;
      __m256i accum_data_v6;
      __m256i accum_data_v7;

      // initial_accum_data will be used to initialize each of the
      // accum_data_* accumulator registers. We compute into it the terms that
      // are identical across columns.
      __m256i initial_accum_data = _mm256_set1_epi32(params.prod_zp_depth);

      // In the channels-are-rows case, we can load bias here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        initial_accum_data = _mm256_add_epi32(
            initial_accum_data,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(params.bias + row)));
      }

      // Adjustments common across columns.
      const std::int32_t rhs_zero_point = params.rhs_zero_point;
      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
        const __m256i lhs_sums_offset = _mm256_mullo_epi32(
            _mm256_set1_epi32(rhs_zero_point),
            _mm256_loadu_si256(
                reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
        initial_accum_data =
            _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
      }

      // Adjustments differing across columns.
      if (has_rhs_sums_offsets) {
        accum_data_v0 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
        accum_data_v1 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1]));
        accum_data_v2 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2]));
        accum_data_v3 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3]));
        accum_data_v4 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4]));
        accum_data_v5 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5]));
        accum_data_v6 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6]));
        accum_data_v7 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7]));
      } else {
        accum_data_v0 = initial_accum_data;
        accum_data_v1 = initial_accum_data;
        accum_data_v2 = initial_accum_data;
        accum_data_v3 = initial_accum_data;
        accum_data_v4 = initial_accum_data;
        accum_data_v5 = initial_accum_data;
        accum_data_v6 = initial_accum_data;
        accum_data_v7 = initial_accum_data;
      }

      // Finally, in the channels-are-columns case, load bias data here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        const __m256i bias_data = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(params.bias + col));
        accum_data_v0 = _mm256_add_epi32(
            accum_data_v0,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(0)));
        accum_data_v1 = _mm256_add_epi32(
            accum_data_v1,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(1)));
        accum_data_v2 = _mm256_add_epi32(
            accum_data_v2,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(2)));
        accum_data_v3 = _mm256_add_epi32(
            accum_data_v3,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(3)));
        accum_data_v4 = _mm256_add_epi32(
            accum_data_v4,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(4)));
        accum_data_v5 = _mm256_add_epi32(
            accum_data_v5,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(5)));
        accum_data_v6 = _mm256_add_epi32(
            accum_data_v6,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(6)));
        accum_data_v7 = _mm256_add_epi32(
            accum_data_v7,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(7)));
      }

      const std::int8_t* lhs_ptr = lhs_col_ptr;
      const std::int8_t* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
        const __m256i lhs_data =
            _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
        const __m256i rhs_data_8bit =
            _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));

        // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
        std::int32_t rhs_data[16];
        const __m128i rhs_data_bottom_lane =
            _mm256_castsi256_si128(rhs_data_8bit);
        const __m128i rhs_data_top_lane =
            _mm256_extracti128_si256(rhs_data_8bit, 1);
        const __m256i rhs_16_bit_dup_low =
            _mm256_cvtepi8_epi16(rhs_data_bottom_lane);
        const __m256i rhs_16_bit_dup_high =
            _mm256_cvtepi8_epi16(rhs_data_top_lane);
        // Now that we have cast the RHS data, we store it so that each value
        // can be separately loaded in the accumulation loop.
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data),
                            rhs_16_bit_dup_low);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8),
                            rhs_16_bit_dup_high);

        const __m256i lhs_data_split =
            _mm256_shuffle_epi8(lhs_data, splitter_idx);
        const __m256i lhs_data_split_expand_bottom =
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
        const __m256i lhs_data_split_expand_top =
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));

        // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
        const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
            lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
        // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
        const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
            lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);

        __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(
            rhs_data));  // Load [0 1 2 3 4 5 6 7]
        __m256i rhs1 = _mm256_lddqu_si256(
            reinterpret_cast<const __m256i*>(rhs_data + 8));  // Load [8 - 15]
        __m256i rhs0_3 =
            _mm256_permute2f128_si256(rhs0, rhs0, 0);  // [0 1 2 3 0 1 2 3]
        __m256i rhs4_7 =
            _mm256_permute2f128_si256(rhs0, rhs0, 0x11);  // [4 5 6 7 4 5 6 7]
        __m256i rhs8_11 =
            _mm256_permute2f128_si256(rhs1, rhs1, 0);  // [8 9 10 11 8 9 10 11]
        __m256i rhs12_15 =
            _mm256_permute2f128_si256(rhs1, rhs1, 17);  // [12 - 15, 12 - 15]

        auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi,
                                  __m256i& accum) {
          accum = _mm256_add_epi32(
              accum, _mm256_madd_epi16(lhs_16_bit_low, rhs_dup_lo));
          accum = _mm256_add_epi32(
              accum, _mm256_madd_epi16(lhs_16_bit_high, rhs_dup_hi));
        };
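        // How process_column accumulates (informal note): _mm256_madd_epi16
        // multiplies corresponding 16-bit lanes and sums adjacent pairs into
        // 32-bit lanes, so the "low" call folds in the first two of the
        // current four depth levels and the "high" call folds in the last
        // two, for all 8 rows at once. rhs_dup_lo / rhs_dup_hi hold the
        // matching pairs of RHS values for one destination column, broadcast
        // across the register.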
        __m256i tmp0, tmp1, tmp2, tmp3;
        tmp0 = _mm256_shuffle_epi32(rhs0_3, 0);
        tmp1 = _mm256_shuffle_epi32(rhs0_3, 0x55);
        process_column(tmp0, tmp1, accum_data_v0);
        tmp2 = _mm256_shuffle_epi32(rhs0_3, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs0_3, 0xff);
        process_column(tmp2, tmp3, accum_data_v1);

        tmp0 = _mm256_shuffle_epi32(rhs4_7, 0);
        tmp1 = _mm256_shuffle_epi32(rhs4_7, 0x55);
        process_column(tmp0, tmp1, accum_data_v2);
        tmp2 = _mm256_shuffle_epi32(rhs4_7, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs4_7, 0xff);
        process_column(tmp2, tmp3, accum_data_v3);

        tmp0 = _mm256_shuffle_epi32(rhs8_11, 0);
        tmp1 = _mm256_shuffle_epi32(rhs8_11, 0x55);
        process_column(tmp0, tmp1, accum_data_v4);
        tmp2 = _mm256_shuffle_epi32(rhs8_11, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs8_11, 0xff);
        process_column(tmp2, tmp3, accum_data_v5);

        tmp0 = _mm256_shuffle_epi32(rhs12_15, 0);
        tmp1 = _mm256_shuffle_epi32(rhs12_15, 0x55);
        process_column(tmp0, tmp1, accum_data_v6);
        tmp2 = _mm256_shuffle_epi32(rhs12_15, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs12_15, 0xff);
        process_column(tmp2, tmp3, accum_data_v7);

        lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
        rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
      }

      if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
        __m256i m_vector;
        __m256i e_vector;
        // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
        m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            params.multiplier_fixedpoint + multiplier_channel));
        e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            params.multiplier_exponent + multiplier_channel));

        const __m256i m_64bit_low =
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
        const __m256i m_64bit_high =
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));

        const __m256i zero_vector = _mm256_setzero_si256();
        const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
        const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
        const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
        const __m256i final_right_shift = _mm256_set1_epi32(31);
        const __m256i final_right_shift_low = _mm256_cvtepi32_epi64(
            _mm256_extracti128_si256(final_right_shift, 0));
        const __m256i final_right_shift_high = _mm256_cvtepi32_epi64(
            _mm256_extracti128_si256(final_right_shift, 1));
        const __m256i convert_to_unsigned_64 =
            _mm256_set1_epi64x(0x8000000000000000);

        __m256i post_scaling_offset = _mm256_setzero_si256();
        // A "half" added for rounding prior to truncation of 64-bit value.
        const __m256i offset_vector = _mm256_add_epi64(
            _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
            convert_to_unsigned_64);
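        // Note on offset_vector (informal): each 64-bit lane adds two
        // corrections at once. The (1 << 30) term is the rounding "half" for
        // the later shift by 31 of the fixed-point product, and the (1 << 63)
        // term (convert_to_unsigned_64) biases the signed value so that the
        // unsigned _mm256_srlv_epi64 below can stand in for the missing
        // signed 64-bit shift, as the longer comment further down explains.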

        if (params.dst_zero_point) {
          post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
        }

        const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);

        // We cannot do
        //
        // scaled_v_low =
        //     _mm256_srav_epi64(scaled_v_low, final_right_shift_low);
        // scaled_v_high =
        //     _mm256_srav_epi64(scaled_v_high, final_right_shift_high);
        //
        // since this instruction is not in AVX2. Instead we use
        // _mm256_srlv_epi64, but this is an unsigned shift, so we applied
        // offsets before (convert_to_unsigned_64) and after
        // (convert_to_signed_halved).
        //
        // The overall process is, for 64-bit scaled accumulator:
        // unsigned_accum = signed_accum + 1 << 63;
        // unsigned_accum = (unsigned_accum >> right_shift) >> 31;
        // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2;

        // There are various ways to repack the results, in the absence of
        // _mm256_cvtepi64_epi32() or anything like it.
        // A.
        // accum_data_v[j] =
        //     _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6),
        //                      _mm256_extract_epi32(scaled_v_high, 4),
        //                      _mm256_extract_epi32(scaled_v_high, 2),
        //                      _mm256_extract_epi32(scaled_v_high, 0),
        //                      _mm256_extract_epi32(scaled_v_low, 6),
        //                      _mm256_extract_epi32(scaled_v_low, 4),
        //                      _mm256_extract_epi32(scaled_v_low, 2),
        //                      _mm256_extract_epi32(scaled_v_low, 0));
        // B.
        // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8);
        // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8);
        // accum_data_v[j] =
        //     _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2),
        //                       _mm256_extract_epi64(scaled_v_high, 0),
        //                       _mm256_extract_epi64(scaled_v_low, 2),
        //                       _mm256_extract_epi64(scaled_v_low, 0));
        // C.
        // scaled_v_low =
        //     _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm);
        // scaled_v_high =
        //     _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm);
        // accum_data_v[j] =
        //     _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20);
        //
        // However, we choose the following because it uses two lighter
        // instructions. The permutation does have a longer latency, but this
        // loop can be unrolled.
        // D.
        // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
        // __m256i results =
        //     _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
        // results = _mm256_permutevar8x32_epi32(results, repack_perm);
        // accum_data_v[j] = _mm256_add_epi32(results, post_scaling_offset);

        // This multiplier code is complex and expensive enough on x86 that
        // we prefer to implement the channels-are-columns case by transposing
        // around it, rather than duplicate it (which would also require
        // duplicating the above code computing the multiplier constants).
        // This is one instance where channels-are-columns has lower
        // performance than channels-are-rows.
        const bool transpose_around_multiplier =
            (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
            (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
        if (transpose_around_multiplier) {
          // Transpose the 8x8 block of accumulators. It will be un-transposed
          // below, after the multiplier implementation.
          intrin_utils::mm256_transpose8x8_epi32<path>(
              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
        }

        auto rounding_right_shift = [=](__m256i& results,
                                        const __m256i& exponent) {
          // Construct the "nudge" value for each lane if the exponent is
          // greater than 0. Otherwise, the nudge is 0.
          const __m256i zeros = _mm256_setzero_si256();
          const __m256i mask_rightshift_gtz =
              _mm256_cmpgt_epi32(exponent, zeros);
          const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
              _mm256_set1_epi32(1),
              _mm256_sub_epi32(exponent, _mm256_set1_epi32(1)));
          __m256i nudge = intrin_utils::mm256_blendv_epi32(
              zeros, one_shift_exp_minus1, mask_rightshift_gtz);
          // Calculate the shifted sum (results + nudge) >> exp.
          const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
          const __m256i shifted_sum = _mm256_srav_epi32(r_plus_nudge, exponent);

          // Identify overflow in each lane and create mask.
          const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
              _mm256_set1_epi32(1),
              _mm256_sub_epi32(_mm256_set1_epi32(31), exponent));
          const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
              results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
          // Fill results with either (results + nudge) >> exponent or
          // 1 << (31 - exp) in the case of overflow.
          results = intrin_utils::mm256_blendv_epi32(
              shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
        };
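        // Scalar sketch of what rounding_right_shift computes per 32-bit lane
        // (informal, exponent being the per-channel right shift, >= 0):
        //   nudge  = exponent > 0 ? (1 << (exponent - 1)) : 0;
        //   result = (x + nudge) >> exponent;  // arithmetic shift
        //   if (x > INT32_MAX - nudge) result = 1 << (31 - exponent);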

        auto apply_multiplier = [=](__m256i& accum) {
          __m256i shifted_accum = _mm256_sllv_epi32(accum, left_shift);
          // Apply the fixed-point part of the multiplier.
          __m256i scaled_v_low = _mm256_mul_epi32(
              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
              m_64bit_low);
          __m256i scaled_v_high = _mm256_mul_epi32(
              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
              m_64bit_high);

          scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
          scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);

          scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
          scaled_v_high =
              _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);

          scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
          __m256i results =
              _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
          results = _mm256_permutevar8x32_epi32(results, repack_perm);
          // Now do a Rounding Right Shift.
          rounding_right_shift(results, right_shift);
          accum = _mm256_add_epi32(results, post_scaling_offset);
        };
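        // Net effect of apply_multiplier, roughly: each 32-bit accumulator is
        // requantized as
        //   accum ~= round(accum * 2^exponent * multiplier_fixedpoint / 2^31)
        //            + dst_zero_point,
        // up to rounding and saturation details, with a positive exponent
        // applied as the left shift above and a negative exponent applied as
        // the rounding right shift (the usual Q31 fixed-point multiplier
        // scheme).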
        apply_multiplier(accum_data_v0);
        apply_multiplier(accum_data_v1);
        apply_multiplier(accum_data_v2);
        apply_multiplier(accum_data_v3);
        apply_multiplier(accum_data_v4);
        apply_multiplier(accum_data_v5);
        apply_multiplier(accum_data_v6);
        apply_multiplier(accum_data_v7);
        // See above comment: here we transpose again to undo the transposition
        // of the 8x8 block of accumulators used to implement the
        // channels-are-columns case.
        if (transpose_around_multiplier) {
          intrin_utils::mm256_transpose8x8_epi32<path>(
              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
        }
      }
      const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
      const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
      const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
                                    (residual_cols == kAvx8bitBlockSize);

      __m256i accum_data_v[kAvx8bitBlockSize];
      if (!store_full_block) {
        accum_data_v[0] = accum_data_v0;
        accum_data_v[1] = accum_data_v1;
        accum_data_v[2] = accum_data_v2;
        accum_data_v[3] = accum_data_v3;
        accum_data_v[4] = accum_data_v4;
        accum_data_v[5] = accum_data_v5;
        accum_data_v[6] = accum_data_v6;
        accum_data_v[7] = accum_data_v7;
      }

      if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
        std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
          accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
          accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
          accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
          accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
          accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
          accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
          accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
          accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
          accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
          accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
          accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
          accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
          accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
          accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
          accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[0 * dst_stride], accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[1 * dst_stride], accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = _mm256_min_epi32(result, clamp_max_v);
            result = _mm256_max_epi32(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
        std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
          accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
          accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
          accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
          accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
          accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
          accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
          accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
          accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
          accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
          accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
          accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
          accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
          accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
          accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
          accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0],
                                                         accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride],
                                                         accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = _mm256_min_epi32(result, clamp_max_v);
            result = _mm256_max_epi32(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
        std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
          accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
          accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
          accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
          accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
          accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
          accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
          accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
          accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
          accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
          accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
          accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
          accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
          accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
          accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
          accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0],
                                                          accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride],
                                                          accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = _mm256_min_epi32(result, clamp_max_v);
            result = _mm256_max_epi32(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
        if (store_full_block) {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride],
                                                 accum_data_v1);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride],
                                                 accum_data_v2);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride],
                                                 accum_data_v3);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride],
                                                 accum_data_v4);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride],
                                                 accum_data_v5);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride],
                                                 accum_data_v6);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride],
                                                 accum_data_v7);
        } else {
          std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < residual_cols; ++j) {
            intrin_utils::mm256_n_storeu_epi32<path>(
                dst_block_ptr, residual_rows, accum_data_v[j]);
            dst_block_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else {
        RUY_DCHECK(false);
      }

      lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
    }  // End row-block loop.

    dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                     kAvx8bitBlockSize * params.dst_stride);
    rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
  }  // End col-block loop.
}  // NOLINT(readability/fn_size)

void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
  Kernel8bitAvx2Impl<Path::kAvx2Fma>(params);
}

template <Path path>
void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit GEMV");

  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  const std::int8_t splitter_idx_data[32] = {
      0, 1, 4, 5, 8, 9, 12, 13,    //
      2, 3, 6, 7, 10, 11, 14, 15,  //
      0, 1, 4, 5, 8, 9, 12, 13,    //
      2, 3, 6, 7, 10, 11, 14, 15   //
  };

  int bias_ptr_block_increment =
      params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;

  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  const std::int32_t* bias_col_ptr = params.bias;
  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
    bias_col_ptr += params.start_row;
  }

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  void* dst_ptr = dst_col_ptr;
  const std::int32_t* bias_ptr = bias_col_ptr;

  const std::int32_t lhs_zero_point = params.lhs_zero_point;
  const bool has_rhs_sums_offsets =
      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
  std::int32_t rhs_sums_offsets[8];
  if (has_rhs_sums_offsets) {
    const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
        _mm256_set1_epi32(lhs_zero_point),
        _mm256_loadu_si256(
            reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
                        rhs_sums_offset_v);
  }

  for (int row = params.start_row; row <= params.last_row;
       row += kAvx8bitBlockSize) {
    const int residual_rows =
        std::min(params.dst_rows - row, kAvx8bitBlockSize);

    const __m256i splitter_idx = _mm256_loadu_si256(
        reinterpret_cast<__m256i const*>(splitter_idx_data));

    __m256i accum_data_v0;

    // Initialize with bias.
    __m256i initial_accum_data =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr));
    bias_ptr += bias_ptr_block_increment;

    // Adjustments common across columns.
    const std::int32_t rhs_zero_point = params.rhs_zero_point;
    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
      const __m256i lhs_sums_offset = _mm256_mullo_epi32(
          _mm256_set1_epi32(rhs_zero_point),
          _mm256_loadu_si256(
              reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
      initial_accum_data =
          _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
    }
    const std::int32_t prod_zp_depth = params.prod_zp_depth;
    if (prod_zp_depth) {
      initial_accum_data = _mm256_add_epi32(initial_accum_data,
                                            _mm256_set1_epi32(prod_zp_depth));
    }

    // Adjustments differing across columns.
    if (has_rhs_sums_offsets) {
      accum_data_v0 = _mm256_sub_epi32(initial_accum_data,
                                       _mm256_set1_epi32(rhs_sums_offsets[0]));
    } else {
      accum_data_v0 = initial_accum_data;
    }

    const std::int8_t* lhs_ptr = lhs_col_ptr;
    const std::int8_t* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
      const __m256i lhs_data =
          _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
      const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr);

      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
      // For simplicity we load 4x the data that we need and process twice the
      // data that we need and store only the data we need.
      std::int32_t rhs_data[2];
      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
      // Now that we have cast the RHS data, we store it so that each value
      // can be separately loaded in the accumulation loop.
      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);

      // NOTE: There may be opportunities for permuting the data in the packing
      // code instead of here.
      const __m256i lhs_data_split =
          _mm256_shuffle_epi8(lhs_data, splitter_idx);
      const __m256i lhs_data_split_expand_bottom =
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
      const __m256i lhs_data_split_expand_top =
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));

      // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
      const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
          lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
      // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
      const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
          lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
      // Accumulate for column 0.
      const std::int32_t low_rhs_value = rhs_data[0];
      const std::int32_t high_rhs_value = rhs_data[1];

      const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
      const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);

      accum_data_v0 = _mm256_add_epi32(
          accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
      accum_data_v0 = _mm256_add_epi32(
          accum_data_v0,
          _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));

      lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
      rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
    }
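    // At this point accum_data_v0 holds the eight int32 accumulators for rows
    // [row, row + 7] of the single destination column, with the bias and
    // zero-point adjustments from above already folded in.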

    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
      __m256i m_vector;
      __m256i e_vector;
      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
      int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
      m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          params.multiplier_fixedpoint + channel));
      e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          params.multiplier_exponent + channel));

      const __m256i m_64bit_low =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
      const __m256i m_64bit_high =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));

      const __m256i zero_vector = _mm256_setzero_si256();
      const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
      const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
      const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
      const __m256i final_right_shift = _mm256_set1_epi32(31);
      const __m256i final_right_shift_low =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0));
      const __m256i final_right_shift_high =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1));
      const __m256i convert_to_unsigned_64 =
          _mm256_set1_epi64x(0x8000000000000000);

      __m256i post_scaling_offset = _mm256_setzero_si256();
      // A "half" added for rounding prior to truncation of 64-bit value.
      const __m256i offset_vector = _mm256_add_epi64(
          _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
          convert_to_unsigned_64);

      if (params.dst_zero_point) {
        post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
      }

      const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);

      // See GEMM version for details of this process.
      {
        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
        // Apply the fixed-point part of the multiplier.
        __m256i scaled_v_low = _mm256_mul_epi32(
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
            m_64bit_low);
        __m256i scaled_v_high = _mm256_mul_epi32(
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
            m_64bit_high);

        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);

        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
        scaled_v_high =
            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);

        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
        __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
        results = _mm256_permutevar8x32_epi32(results, repack_perm);

        // Now do a Rounding Right Shift.
        // First, construct the nudge value for each lane.
        const __m256i zeros = _mm256_setzero_si256();
        const __m256i mask_rightshift_gtz =
            _mm256_cmpgt_epi32(right_shift, zeros);
        const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
            _mm256_set1_epi32(1),
            _mm256_sub_epi32(right_shift, _mm256_set1_epi32(1)));
        __m256i nudge = intrin_utils::mm256_blendv_epi32(
            zeros, one_shift_exp_minus1, mask_rightshift_gtz);
        // Calculate the shifted sum (results + nudge) >> exp.
        const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
        const __m256i shifted_sum =
            _mm256_srav_epi32(r_plus_nudge, right_shift);

        // Identify overflow in each lane and create mask.
        const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
            _mm256_set1_epi32(1),
            _mm256_sub_epi32(_mm256_set1_epi32(31), right_shift));
        const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
            results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
        // Fill results with either (results + nudge) >> exponent or
        // 1 << (31 - exp) in the case of overflow.
        results = intrin_utils::mm256_blendv_epi32(
            shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);

        accum_data_v0 = _mm256_add_epi32(results, post_scaling_offset);
      }
    }
    const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
    const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);

    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = _mm256_min_epi32(result, clamp_max_v);
      result = _mm256_max_epi32(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
                                                       result);
      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = _mm256_min_epi32(result, clamp_max_v);
      result = _mm256_max_epi32(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
                                                       result);
      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = _mm256_min_epi32(result, clamp_max_v);
      result = _mm256_max_epi32(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows,
                                                        result);
      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
      std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
      intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows,
                                               accum_data_v0);
      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else {
      RUY_DCHECK(false);
    }

    lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
  }  // End row-block loop.

  dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                   kAvx8bitBlockSize * params.dst_stride);
  rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
}  // NOLINT(readability/fn_size)

void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
  Kernel8bitAvx2SingleColImpl<Path::kAvx2Fma>(params);
}

void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma float");
  KernelFloatAvxCommon<Path::kAvx2Fma>(params);
}

void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma float GEMV");
  KernelFloatAvxCommonSingleCol<Path::kAvx2Fma>(params);
}

#endif  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)

}  // namespace ruy