1 /* Copyright 2020 Google LLC. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <cstdint>
18 #include <cstring>
19
20 #include "ruy/check_macros.h"
21 #include "ruy/kernel_common.h"
22 #include "ruy/kernel_x86.h"
23 #include "ruy/opt_set.h"
24 #include "ruy/platform.h"
25 #include "ruy/profiler/instrumentation.h"
26
27 #if RUY_PLATFORM_AVX && RUY_OPT(ASM)
28 #include <immintrin.h> // IWYU pragma: keep
29 #endif
30
31 namespace ruy {
32
33 #if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))
34
35 void Kernel8bitAvx(const KernelParams8bit<8, 8>&) {
36 // CPU-ID-based checks should disable the path that would reach this point.
37 RUY_DCHECK(false);
38 }
39
40 void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>&) {
41 // CPU-ID-based checks should disable the path that would reach this point.
42 RUY_DCHECK(false);
43 }
44
45 void KernelFloatAvx(const KernelParamsFloat<8, 8>&) {
46 // CPU-ID-based checks should disable the path that would reach this point.
47 RUY_DCHECK(false);
48 }
49
50 void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>&) {
51 // CPU-ID-based checks should disable the path that would reach this point.
52 RUY_DCHECK(false);
53 }
54
55 #else // RUY_PLATFORM_AVX && RUY_OPT(ASM)
56
57 static constexpr int kAvx8bitBlockSize = 8;
58 static constexpr int kAvx8bitInnerSize = 4;
59
60 namespace {
61 namespace intrin_utils {
62
63 template <>
64 inline __m256i mm256_shuffle_epi8<Path::kAvx>(const __m256i& a,
65 const __m256i& b) {
66 __m128i a_lo = _mm256_extractf128_si256(a, 0);
67 __m128i a_hi = _mm256_extractf128_si256(a, 1);
68 __m128i b_lo = _mm256_extractf128_si256(b, 0);
69 __m128i b_hi = _mm256_extractf128_si256(b, 1);
70 __m128i dst_lo = _mm_shuffle_epi8(a_lo, b_lo);
71 __m128i dst_hi = _mm_shuffle_epi8(a_hi, b_hi);
72 return _mm256_set_m128i(dst_hi, dst_lo);
73 }
74
75 template <>
76 inline __m128i mm256_extracti128_si256<Path::kAvx>(const __m256i& a,
77 const int imm) {
78 switch (imm) {
79 case 0:
80 return _mm256_extractf128_si256(a, 0);
81 case 1:
82 return _mm256_extractf128_si256(a, 1);
83 default:
84 RUY_DCHECK_LT(imm, 2);
85 return _mm_setzero_si128();
86 }
87 }
88
89 template <Path path>
90 inline __m256i mm256_cvtepi8_epi16(const __m128i& a) {
91 // Take the upper 64 bits of a and put them in the lower 64 bits of 'hi'.
92 __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128());
93 return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a));
94 }
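// Illustrative scalar sketch (comments only, not compiled): the element-wise
// semantics the emulation above is meant to provide, with `src`/`dst` as
// purely illustrative names for the 16 input bytes and 16 output lanes.
//
//   std::int8_t src[16];   // bytes of `a`
//   std::int16_t dst[16];  // lanes of the returned __m256i
//   for (int i = 0; i < 16; ++i) dst[i] = static_cast<std::int16_t>(src[i]);
//
// The SSE4.1 _mm_cvtepi8_epi16 only converts the low 8 bytes, hence the
// unpackhi above to bring bytes 8..15 into position for the second convert.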
95
96 template <Path path>
97 inline __m256i mm256_cvtepi32_epi64(const __m128i& a) {
98 // sign extend the 32-bit values in the lower 64 bits of a.
99 __m128i lo = _mm_cvtepi32_epi64(a);
100 __m128i hi = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(a, _mm_setzero_si128()));
101 return _mm256_set_m128i(hi, lo);
102 }
103
104 inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b,
105 const int imm) {
106 __m128i tmp = _mm_setzero_si128();
107 if (!(imm & 8)) {
108 switch (imm & 3) {
109 case 0:
110 return _mm256_extractf128_si256(a, 0);
111 case 1:
112 return _mm256_extractf128_si256(a, 1);
113 case 2:
114 return _mm256_extractf128_si256(b, 0);
115 case 3:
116 return _mm256_extractf128_si256(b, 1);
117 }
118 }
119 return tmp;
120 }
121
122 template <Path path>
123 inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b,
124 const int imm) {
125 const int lo_imm = imm & 15;
126 __m128i lo = mm_permute_helper(a, b, lo_imm);
127 const int hi_imm = (imm >> 4) & 15;
128 __m128i hi = mm_permute_helper(a, b, hi_imm);
129 return _mm256_set_m128i(hi, lo);
130 }
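// Hedged usage sketch (comments only) of how the immediate selects the two
// 128-bit output lanes, matching the 0x20 / 0x31 immediates used further
// below; `lane(x, k)` is illustrative shorthand for the k-th 128-bit half.
//
//   // imm = 0x20: low lane <- lane(a, 0), high lane <- lane(b, 0)
//   __m256i both_lows = mm256_permute2x128_si256<path>(a, b, 0x20);
//   // imm = 0x31: low lane <- lane(a, 1), high lane <- lane(b, 1)
//   __m256i both_highs = mm256_permute2x128_si256<path>(a, b, 0x31);
//
// Setting bit 3 of a 4-bit field would zero that output lane; the kernels in
// this file never pass such an immediate, so mm_permute_helper only needs to
// fall through to the zero return in that case.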
131
132 template <Path path>
133 inline __m256i mm256_max_epi32(const __m256i& a, const __m256i& b) {
134 __m128i a_lo = _mm256_extractf128_si256(a, 0);
135 __m128i a_hi = _mm256_extractf128_si256(a, 1);
136 __m128i b_lo = _mm256_extractf128_si256(b, 0);
137 __m128i b_hi = _mm256_extractf128_si256(b, 1);
138 __m128i lo = _mm_max_epi32(a_lo, b_lo);
139 __m128i hi = _mm_max_epi32(a_hi, b_hi);
140 return _mm256_set_m128i(hi, lo);
141 }
142
143 template <Path path>
144 inline __m256i mm256_min_epi32(const __m256i& a, const __m256i& b) {
145 __m128i a_lo = _mm256_extractf128_si256(a, 0);
146 __m128i a_hi = _mm256_extractf128_si256(a, 1);
147 __m128i b_lo = _mm256_extractf128_si256(b, 0);
148 __m128i b_hi = _mm256_extractf128_si256(b, 1);
149 __m128i lo = _mm_min_epi32(a_lo, b_lo);
150 __m128i hi = _mm_min_epi32(a_hi, b_hi);
151 return _mm256_set_m128i(hi, lo);
152 }
153
154 template <Path path>
155 inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) {
156 __m128i a_lo = _mm256_extractf128_si256(a, 0);
157 __m128i a_hi = _mm256_extractf128_si256(a, 1);
158 __m128i b_lo = _mm256_extractf128_si256(b, 0);
159 __m128i b_hi = _mm256_extractf128_si256(b, 1);
160 __m128i lo = _mm_add_epi32(a_lo, b_lo);
161 __m128i hi = _mm_add_epi32(a_hi, b_hi);
162 return _mm256_set_m128i(hi, lo);
163 }
164
165 template <Path path>
166 inline __m256i mm256_add_epi64(const __m256i& a, const __m256i& b) {
167 __m128i a_lo = _mm256_extractf128_si256(a, 0);
168 __m128i a_hi = _mm256_extractf128_si256(a, 1);
169 __m128i b_lo = _mm256_extractf128_si256(b, 0);
170 __m128i b_hi = _mm256_extractf128_si256(b, 1);
171 __m128i lo = _mm_add_epi64(a_lo, b_lo);
172 __m128i hi = _mm_add_epi64(a_hi, b_hi);
173 return _mm256_set_m128i(hi, lo);
174 }
175
176 template <Path path>
177 inline __m256i mm256_slli_epi64(const __m256i& a, int imm) {
178 __m128i a_lo = _mm256_extractf128_si256(a, 0);
179 __m128i a_hi = _mm256_extractf128_si256(a, 1);
180 __m128i lo = _mm_slli_epi64(a_lo, imm);
181 __m128i hi = _mm_slli_epi64(a_hi, imm);
182 return _mm256_set_m128i(hi, lo);
183 }
184
185 template <Path path>
186 inline __m256i mm256_mullo_epi32(const __m256i& a, const __m256i& b) {
187 __m128i a_lo = _mm256_extractf128_si256(a, 0);
188 __m128i a_hi = _mm256_extractf128_si256(a, 1);
189 __m128i b_lo = _mm256_extractf128_si256(b, 0);
190 __m128i b_hi = _mm256_extractf128_si256(b, 1);
191 __m128i lo = _mm_mullo_epi32(a_lo, b_lo);
192 __m128i hi = _mm_mullo_epi32(a_hi, b_hi);
193 return _mm256_set_m128i(hi, lo);
194 }
195
196 // Defined as a macro since `imm` must be an immediate.
197 #define BlendM128_epi32(a, b, imm) \
198 _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), imm))
199
200 // Defined as a macro since `imm` must be an immediate.
201 #define BlendM128_epi64(a, b, imm) \
202 _mm_castpd_si128(_mm_blend_pd(_mm_castsi128_pd(a), _mm_castsi128_pd(b), imm))
203
204 // Defined as a macro since `imm` must be an immediate.
205 #define mm256_blend_epi32(ans, a, b, imm) \
206 __m128i a_lo = _mm256_extractf128_si256(a, 0); \
207 __m128i a_hi = _mm256_extractf128_si256(a, 1); \
208 __m128i b_lo = _mm256_extractf128_si256(b, 0); \
209 __m128i b_hi = _mm256_extractf128_si256(b, 1); \
210 __m128i lo = BlendM128_epi32(a_lo, b_lo, imm & 0xe); \
211 __m128i hi = BlendM128_epi32(a_hi, b_hi, imm >> 4); \
212 ans = _mm256_set_m128i(hi, lo);
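// Per-element behaviour of the blend macros above, as a scalar sketch
// (comments only; `a`, `b`, `dst` stand for the 32-bit or 64-bit lanes):
//
//   // BlendM128_epi32: for i in 0..3, dst[i] = ((imm >> i) & 1) ? b[i] : a[i]
//   // BlendM128_epi64: for i in 0..1, dst[i] = ((imm >> i) & 1) ? b[i] : a[i]
//
// mm256_blend_epi32 applies the 32-bit blend to each 128-bit half using the
// low and high nibbles of `imm`, which reproduces _mm256_blend_epi32 for the
// immediate actually used in this file (0xaa).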
213
214 #define mm256_shuffle_epi32(ans, a, a_lo, a_hi, imm) \
215 a_lo = _mm256_extractf128_si256(a, 0); \
216 a_hi = _mm256_extractf128_si256(a, 1); \
217 ans = _mm256_set_m128i(_mm_shuffle_epi32(a_hi, imm), \
218 _mm_shuffle_epi32(a_lo, imm));
219
220 template <Path path>
221 inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) {
222 __m128i a_lo = _mm256_extractf128_si256(a, 0);
223 __m128i a_hi = _mm256_extractf128_si256(a, 1);
224 __m128i b_lo = _mm256_extractf128_si256(b, 0);
225 __m128i b_hi = _mm256_extractf128_si256(b, 1);
226 __m128i lo = _mm_madd_epi16(a_lo, b_lo);
227 __m128i hi = _mm_madd_epi16(a_hi, b_hi);
228 return _mm256_set_m128i(hi, lo);
229 }
230
231 inline __m128i mm_srlv_epi64(const __m128i& a, const __m128i& b) {
232 // Shift both elements of a by the lower 64 bits of b.
233 __m128i res_lo = _mm_srl_epi64(a, b);
234 // Shift both elements of a by the upper 64 bits of b.
235 __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128());
236 __m128i res_hi = _mm_srl_epi64(a, hi_count);
237 // Take the lower 64 bits of res_lo and the upper 64 bits of res_hi:
238 // 1. Swap the upper and lower 64 bits of res_hi
239 __m128i tmp_hi =
240 _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1));
241 // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
242 return _mm_unpacklo_epi64(res_lo, tmp_hi);
243 }
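// Scalar sketch of mm_srlv_epi64 (comments only; `a`, `b`, `dst` denote the
// two unsigned 64-bit lanes of each operand):
//
//   for (int i = 0; i < 2; ++i) dst[i] = a[i] >> b[i];  // logical shift
//
// _mm_srl_epi64 can only shift both lanes by a single count, so the code
// above runs it twice and recombines. The kernels below rely on this being a
// logical (unsigned) shift and pre-offset the values accordingly, as
// explained in the multiplier comments further down.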
244
245 template <Path path>
246 inline __m256i mm256_srlv_epi64(const __m256i& a, const __m256i& b) {
247 __m128i a_lo = _mm256_extractf128_si256(a, 0);
248 __m128i a_hi = _mm256_extractf128_si256(a, 1);
249 __m128i b_lo = _mm256_extractf128_si256(b, 0);
250 __m128i b_hi = _mm256_extractf128_si256(b, 1);
251 __m128i lo = mm_srlv_epi64(a_lo, b_lo);
252 __m128i hi = mm_srlv_epi64(a_hi, b_hi);
253 return _mm256_set_m128i(hi, lo);
254 }
255
256 template <Path path>
257 inline __m128i mm_sllv_epi64(const __m128i& a, const __m128i& b) {
258 // Shift both elements of a by the lower 64 bits of b.
259 __m128i res_lo = _mm_sll_epi64(a, b);
260 // Shift both elements of a by the upper 64 bits of b.
261 __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128());
262 __m128i res_hi = _mm_sll_epi64(a, hi_count);
263 // Take the lower 64 bits of res_lo and the upper 64 bits of res_hi:
264 // 1. Swap the upper and lower 64 bits of res_hi
265 __m128i tmp_hi =
266 _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1));
267 // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
268 return _mm_unpacklo_epi64(res_lo, tmp_hi);
269 }
270
271 template <Path path>
272 inline __m256i mm256_sllv_epi64(const __m256i& a, const __m256i& b) {
273 __m128i a_lo = _mm256_extractf128_si256(a, 0);
274 __m128i a_hi = _mm256_extractf128_si256(a, 1);
275 __m128i b_lo = _mm256_extractf128_si256(b, 0);
276 __m128i b_hi = _mm256_extractf128_si256(b, 1);
277 __m128i lo = mm_sllv_epi64<path>(a_lo, b_lo);
278 __m128i hi = mm_sllv_epi64<path>(a_hi, b_hi);
279 return _mm256_set_m128i(hi, lo);
280 }
281
282 #define PermuteM128_epi32(a, imm) \
283 _mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(a), imm));
284
285 inline __m128i mm_sllv_epi32(const __m128i& a, const __m128i& b) {
286 // shift all elements of a by first 32bits of b.
287 __m128i res0 = _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1));
288
289 // put bits 32-63 of b in the first slot.
290 __m128i tmp1 = PermuteM128_epi32(b, 1);
291 // put bits 32-63 of a in the first slot.
292 __m128i a1 = PermuteM128_epi32(a, 1);
293 // shift all elements of a by second 32bits of b.
294 __m128i res1 =
295 _mm_sll_epi32(a1, BlendM128_epi32(_mm_setzero_si128(), tmp1, 1));
296
297 // put bits 64-95 of b in the first slot.
298 __m128i tmp2 = PermuteM128_epi32(b, 2);
299 // shift all elements of a by third 32bits of b.
300 __m128i res2 =
301 _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), tmp2, 1));
302
303 // put bits 96-127 of b in the first slot.
304 __m128i tmp3 = PermuteM128_epi32(b, 3);
305 // put bits 96-127 of a in the third slot.
306 __m128i a3 = PermuteM128_epi32(a, 48);
307 // shift all elements of a3 by fourth 32bits of b.
308 __m128i res3 =
309 _mm_sll_epi32(a3, BlendM128_epi32(_mm_setzero_si128(), tmp3, 1));
310
311 // Take bits 0-31 of res0, bits 0-31 of res1,
312 // bits 64-95 of res2, and bits 64-95 of res3.
313 // res0 _ _ _ 0
314 // res1 _ _ _ 1
315 // res2 _ 2 _ _
316 // res3 _ 3 _ _
317 // f_01 _ _ 1 0
318 // f_23 _ _ 3 2
319 __m128i f_01 = _mm_unpacklo_epi32(res0, res1);
320 __m128i f_23 = _mm_unpackhi_epi32(res2, res3);
321 // The lower 64 bits of f_01 and the lower 64 bits of f_23.
322 return _mm_unpacklo_epi64(f_01, f_23);
323 }
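// Scalar sketch of mm_sllv_epi32 (comments only; `a`, `b`, `dst` denote the
// four 32-bit lanes of each operand):
//
//   for (int i = 0; i < 4; ++i) dst[i] = a[i] << b[i];
//
// The permute/blend sequence above exists only because SSE lacks a per-lane
// variable shift: each _mm_sll_epi32 shifts all lanes by the count held in
// the low 32 bits of its second operand, so one count is isolated at a time
// and the four partial results are then recombined.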
324
325 template <Path path>
326 inline __m256i mm256_sllv_epi32(const __m256i& a, const __m256i& b) {
327 __m128i a_lo = _mm256_extractf128_si256(a, 0);
328 __m128i a_hi = _mm256_extractf128_si256(a, 1);
329 __m128i b_lo = _mm256_extractf128_si256(b, 0);
330 __m128i b_hi = _mm256_extractf128_si256(b, 1);
331 __m128i lo = mm_sllv_epi32(a_lo, b_lo);
332 __m128i hi = mm_sllv_epi32(a_hi, b_hi);
333 return _mm256_set_m128i(hi, lo);
334 }
335
336 template <Path path>
337 inline __m256i mm256_sub_epi32(const __m256i& a, const __m256i& b) {
338 __m128i a_lo = _mm256_extractf128_si256(a, 0);
339 __m128i a_hi = _mm256_extractf128_si256(a, 1);
340 __m128i b_lo = _mm256_extractf128_si256(b, 0);
341 __m128i b_hi = _mm256_extractf128_si256(b, 1);
342 __m128i lo = _mm_sub_epi32(a_lo, b_lo);
343 __m128i hi = _mm_sub_epi32(a_hi, b_hi);
344 return _mm256_set_m128i(hi, lo);
345 }
346
347 template <Path path>
348 inline __m256i mm256_mul_epi32(const __m256i& a, const __m256i& b) {
349 __m128i a_lo = _mm256_extractf128_si256(a, 0);
350 __m128i a_hi = _mm256_extractf128_si256(a, 1);
351 __m128i b_lo = _mm256_extractf128_si256(b, 0);
352 __m128i b_hi = _mm256_extractf128_si256(b, 1);
353 __m128i lo = _mm_mul_epi32(a_lo, b_lo);
354 __m128i hi = _mm_mul_epi32(a_hi, b_hi);
355 return _mm256_set_m128i(hi, lo);
356 }
357
358 // Perform the equivalent of mm256_permutevar8x32 with
359 // a second argument of {7, 5, 3, 1, 6, 4, 2, 0}
360 template <Path path>
361 inline __m256i PermuteEpi32EvenOdds(const __m256i& a) {
362 // a_lo = 3 2 1 0
363 __m128i a_lo = _mm256_extractf128_si256(a, 0);
364 // a_hi = 7 6 5 4
365 __m128i a_hi = _mm256_extractf128_si256(a, 1);
366 // shuffle a_lo to get 3 1 2 0
367 __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8);
368 // shuffle a_hi to get 7 5 6 4
369 __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8);
370 // unpack lo 64 of tmp_lo and tmp_hi to get 6 4 2 0
371 __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi);
372 // unpack hi 64 of tmp_lo and tmp_hi to get 7 5 3 1
373 __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi);
374 return _mm256_set_m128i(res_hi, res_lo);
375 }
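// Scalar sketch (comments only) of the permutation above, with `a` and `dst`
// standing for the eight 32-bit lanes, listed from low to high:
//
//   dst = { a[0], a[2], a[4], a[6], a[1], a[3], a[5], a[7] };
//
// Even-indexed lanes are gathered into the low half and odd-indexed lanes
// into the high half, which is how the 64-bit multiplier results are repacked
// into eight 32-bit lanes in the kernels below.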
376
377 template <Path path>
378 inline __m256i AddBiasEpi32(const __m256i& a, const int32_t* bias, int offset) {
379 const __m256i bias0 = _mm256_set1_epi32(*(bias + offset));
380 return mm256_add_epi32<path>(a, bias0);
381 }
382
383 __m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
384 const __m256i& mask) {
385 __m256 result =
386 _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
387 _mm256_castsi256_ps(mask));
388 return _mm256_castps_si256(result);
389 }
390
391 template <Path path>
392 inline __m256i mm256_cmpgt_epi32(const __m256i& a, const __m256i& b) {
393 __m128i a_lo = _mm256_extractf128_si256(a, 0);
394 __m128i a_hi = _mm256_extractf128_si256(a, 1);
395 __m128i b_lo = _mm256_extractf128_si256(b, 0);
396 __m128i b_hi = _mm256_extractf128_si256(b, 1);
397 __m128i lo = _mm_cmpgt_epi32(a_lo, b_lo);
398 __m128i hi = _mm_cmpgt_epi32(a_hi, b_hi);
399 return _mm256_set_m128i(hi, lo);
400 }
401
402 template <Path path>
403 inline __m256i mm256_srav_epi32(const __m256i& a, const __m256i& b) {
404 __m128i a_lo = _mm256_extractf128_si256(a, 0);
405 __m128i a_hi = _mm256_extractf128_si256(a, 1);
406
407 __m128i r0 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 0));
408 __m128i r1 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 1));
409 __m128i r2 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 2));
410 __m128i r3 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 3));
411 __m128i r4 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 4));
412 __m128i r5 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 5));
413 __m128i r6 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 6));
414 __m128i r7 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 7));
415
416 // get element 0 from r0, element 1 from r1
417 __m128i r01 = BlendM128_epi32(r0, r1, 2);
418 // get element 2 from r2, element 3 from r3
419 __m128i r23 = BlendM128_epi32(r2, r3, 8);
420 // get element 0 from r4, element 1 from r5
421 __m128i r45 = BlendM128_epi32(r4, r5, 2);
422 // get element 2 from r6, element 3 from r7
423 __m128i r67 = BlendM128_epi32(r6, r7, 8);
424 // get lower 64 bits of r01, upper 64 bits of r23
425 __m128i r0123 = BlendM128_epi64(r01, r23, 2);
426 // get lower 64 bits of r45, upper 64 bits of r67
427 __m128i r4567 = BlendM128_epi64(r45, r67, 2);
428 return _mm256_set_m128i(r4567, r0123);
429 }
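// Scalar sketch of mm256_srav_epi32 (comments only; `a`, `b`, `dst` denote
// the eight signed 32-bit lanes):
//
//   for (int i = 0; i < 8; ++i) dst[i] = a[i] >> b[i];  // arithmetic shift
//
// Each _mm_srai_epi32 above shifts an entire 128-bit half by a single count
// extracted from `b`; the blends then keep only the lane whose count was the
// one actually wanted.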
430
431 // AVX doesn't have fused multiply-add so we define an inline function to be
432 // used in the common code following.
433 template <>
434 inline __m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b,
435 const __m256& c) {
436 const __m256 prod = _mm256_mul_ps(a, b);
437 return _mm256_add_ps(prod, c);
438 }
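// Hedged usage sketch (comments only): the shared float kernels declared in
// kernel_x86.h are expected to call this helper roughly as below; the
// variable names are illustrative, not taken from that header.
//
//   accum = intrin_utils::MulAdd<path>(lhs_v, rhs_dup_v, accum);
//
// On kAvx this costs a separate multiply and add (two roundings); paths with
// FMA support can specialize MulAdd to a single fused instruction instead.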
439
440 } // namespace intrin_utils
441 } // namespace
442
443 template <Path path>
444 void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
445 profiler::ScopeLabel label("Kernel kAvx 8-bit");
446 const std::int8_t splitter_idx_data[32] = {
447 0, 1, 4, 5, 8, 9, 12, 13, //
448 2, 3, 6, 7, 10, 11, 14, 15, //
449 0, 1, 4, 5, 8, 9, 12, 13, //
450 2, 3, 6, 7, 10, 11, 14, 15 //
451 };
452
453 std::int32_t dst_stride = 0;
454 if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
455 (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
456 dst_stride = params.dst_stride;
457 } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
458 dst_stride = params.dst_stride / sizeof(std::int16_t);
459 } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
460 dst_stride = params.dst_stride / sizeof(std::int32_t);
461 } else {
462 RUY_DCHECK(false);
463 }
464
465 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
466 void* dst_col_ptr = params.dst_base_ptr;
467
468 for (int col = params.start_col; col <= params.last_col;
469 col += kAvx8bitBlockSize) {
470 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
471 void* dst_ptr = dst_col_ptr;
472
473 const std::int32_t lhs_zero_point = params.lhs_zero_point;
474 const bool has_rhs_sums_offsets =
475 (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
476 std::int32_t rhs_sums_offsets[8];
477 if (has_rhs_sums_offsets) {
478 const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>(
479 _mm256_set1_epi32(lhs_zero_point),
480 _mm256_loadu_si256(
481 reinterpret_cast<__m256i const*>(&params.rhs_sums[col])));
482 _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
483 rhs_sums_offset_v);
484 }
485
486 for (int row = params.start_row; row <= params.last_row;
487 row += kAvx8bitBlockSize) {
488 int channel =
489 (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
490 int multiplier_channel =
491 (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
492 const int residual_rows =
493 std::min(params.dst_rows - row, kAvx8bitBlockSize);
494 const int residual_cols =
495 std::min(params.dst_cols - col, kAvx8bitBlockSize);
496
497 const __m256i splitter_idx = _mm256_loadu_si256(
498 reinterpret_cast<__m256i const*>(splitter_idx_data));
499
500 __m256i accum_data_v0;
501 __m256i accum_data_v1;
502 __m256i accum_data_v2;
503 __m256i accum_data_v3;
504 __m256i accum_data_v4;
505 __m256i accum_data_v5;
506 __m256i accum_data_v6;
507 __m256i accum_data_v7;
508
509 // initial_accum_data will be used to initialize each of the accum_data_*
510 // accumulator registers. We compute into it the terms that are identical
511 // across columns.
512 __m128i initial_accum_data_lo = _mm_set1_epi32(params.prod_zp_depth);
513 __m128i initial_accum_data_hi = _mm_set1_epi32(params.prod_zp_depth);
514
515 // In the channels-are-rows case, we can load bias here.
516 if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
517 !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
518 initial_accum_data_lo = _mm_add_epi32(
519 initial_accum_data_lo,
520 _mm_loadu_si128(
521 reinterpret_cast<const __m128i*>(params.bias + row)));
522 initial_accum_data_hi = _mm_add_epi32(
523 initial_accum_data_hi,
524 _mm_loadu_si128(
525 reinterpret_cast<const __m128i*>(params.bias + row + 4)));
526 }
527
528 // Adjustments common across columns.
529 const std::int32_t rhs_zero_point = params.rhs_zero_point;
530 if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
531 const __m128i rhs_zp = _mm_set1_epi32(rhs_zero_point);
532 const __m128i lhs_sums_offset_lo = _mm_mullo_epi32(
533 rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>(
534 &params.lhs_sums[row])));
535 const __m128i lhs_sums_offset_hi = _mm_mullo_epi32(
536 rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>(
537 &params.lhs_sums[row + 4])));
538
539 initial_accum_data_lo =
540 _mm_sub_epi32(initial_accum_data_lo, lhs_sums_offset_lo);
541 initial_accum_data_hi =
542 _mm_sub_epi32(initial_accum_data_hi, lhs_sums_offset_hi);
543 }
544
545 // Adjustments differing across columns.
546 if (has_rhs_sums_offsets) {
547 __m256i initial_accum_data =
548 _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo);
549
550 accum_data_v0 = intrin_utils::mm256_sub_epi32<path>(
551 initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
552 accum_data_v1 = intrin_utils::mm256_sub_epi32<path>(
553 initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1]));
554 accum_data_v2 = intrin_utils::mm256_sub_epi32<path>(
555 initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2]));
556 accum_data_v3 = intrin_utils::mm256_sub_epi32<path>(
557 initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3]));
558 accum_data_v4 = intrin_utils::mm256_sub_epi32<path>(
559 initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4]));
560 accum_data_v5 = intrin_utils::mm256_sub_epi32<path>(
561 initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5]));
562 accum_data_v6 = intrin_utils::mm256_sub_epi32<path>(
563 initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6]));
564 accum_data_v7 = intrin_utils::mm256_sub_epi32<path>(
565 initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7]));
566 } else {
567 __m256i initial_accum_data =
568 _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo);
569 accum_data_v0 = initial_accum_data;
570 accum_data_v1 = initial_accum_data;
571 accum_data_v2 = initial_accum_data;
572 accum_data_v3 = initial_accum_data;
573 accum_data_v4 = initial_accum_data;
574 accum_data_v5 = initial_accum_data;
575 accum_data_v6 = initial_accum_data;
576 accum_data_v7 = initial_accum_data;
577 }
578
579 // Finally, in the channels-are-columns case, load bias data here.
580 if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
581 (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
582 accum_data_v0 = intrin_utils::AddBiasEpi32<path>(accum_data_v0,
583 params.bias + col, 0);
584 accum_data_v1 = intrin_utils::AddBiasEpi32<path>(accum_data_v1,
585 params.bias + col, 1);
586 accum_data_v2 = intrin_utils::AddBiasEpi32<path>(accum_data_v2,
587 params.bias + col, 2);
588 accum_data_v3 = intrin_utils::AddBiasEpi32<path>(accum_data_v3,
589 params.bias + col, 3);
590 accum_data_v4 = intrin_utils::AddBiasEpi32<path>(accum_data_v4,
591 params.bias + col, 4);
592 accum_data_v5 = intrin_utils::AddBiasEpi32<path>(accum_data_v5,
593 params.bias + col, 5);
594 accum_data_v6 = intrin_utils::AddBiasEpi32<path>(accum_data_v6,
595 params.bias + col, 6);
596 accum_data_v7 = intrin_utils::AddBiasEpi32<path>(accum_data_v7,
597 params.bias + col, 7);
598 }
599
600 const std::int8_t* lhs_ptr = lhs_col_ptr;
601 const std::int8_t* rhs_ptr = rhs_col_ptr;
602 for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
603 const __m256i lhs_data =
604 _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
605 const __m256i rhs_data_8bit =
606 _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));
607
608 // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
609 std::int32_t rhs_data[16];
610 const __m128i rhs_data_bottom_lane =
611 _mm256_castsi256_si128(rhs_data_8bit);
612 const __m128i rhs_data_top_lane =
613 _mm256_extractf128_si256(rhs_data_8bit, 1);
614 const __m256i rhs_16_bit_dup_low =
615 intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_bottom_lane);
616 const __m256i rhs_16_bit_dup_high =
617 intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_top_lane);
618 // Now that we have cast the RHS data, we store it so that each value
619 // can be separately loaded in the accumulation loop.
620 _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data),
621 rhs_16_bit_dup_low);
622 _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8),
623 rhs_16_bit_dup_high);
624
625 // NOTE: There may be opportunities for permuting the data in the
626 // packing code instead of here.
627 const __m256i lhs_data_split =
628 intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx);
629 const __m256i lhs_data_split_expand_bottom =
630 intrin_utils::mm256_cvtepi8_epi16<path>(
631 _mm256_extractf128_si256(lhs_data_split, 0));
632 const __m256i lhs_data_split_expand_top =
633 intrin_utils::mm256_cvtepi8_epi16<path>(
634 _mm256_extractf128_si256(lhs_data_split, 1));
635
636 // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
637 const __m256i lhs_16_bit_low =
638 intrin_utils::mm256_permute2x128_si256<path>(
639 lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
640 // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
641 const __m256i lhs_16_bit_high =
642 intrin_utils::mm256_permute2x128_si256<path>(
643 lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
644
645 __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(
646 rhs_data)); // Load [0 1 2 3 4 5 6 7]
647 __m256i rhs1 = _mm256_lddqu_si256(
648 reinterpret_cast<const __m256i*>(rhs_data + 8)); // Load [8 - 15]
649 __m256i rhs0_3 =
650 _mm256_permute2f128_si256(rhs0, rhs0, 0); // [0 1 2 3 0 1 2 3]
651 __m256i rhs4_7 =
652 _mm256_permute2f128_si256(rhs0, rhs0, 0x11); // [4 5 6 7 4 5 6 7]
653 __m256i rhs8_11 =
654 _mm256_permute2f128_si256(rhs1, rhs1, 0); // [8 9 10 11 8 9 10 11]
655 __m256i rhs12_15 =
656 _mm256_permute2f128_si256(rhs1, rhs1, 17); // [12 - 15, 12 - 15]
657
658 auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi,
659 __m256i& accum) {
660 // Perform mul-adds on low and high components of accum separately.
661 __m128i accum_lo = _mm256_extractf128_si256(accum, 0);
662 __m128i accum_hi = _mm256_extractf128_si256(accum, 1);
663
664 __m128i lhs_lo_0 = _mm256_extractf128_si256(lhs_16_bit_low, 0);
665 __m128i lhs_lo_1 = _mm256_extractf128_si256(lhs_16_bit_low, 1);
666 __m128i rhs_dup_lo_0 = _mm256_extractf128_si256(rhs_dup_lo, 0);
667 __m128i rhs_dup_lo_1 = _mm256_extractf128_si256(rhs_dup_lo, 1);
668 __m128i lo_0 = _mm_madd_epi16(lhs_lo_0, rhs_dup_lo_0);
669 __m128i lo_1 = _mm_madd_epi16(lhs_lo_1, rhs_dup_lo_1);
670
671 accum_lo = _mm_add_epi32(accum_lo, lo_0);
672 accum_hi = _mm_add_epi32(accum_hi, lo_1);
673
674 __m128i lhs_hi_0 = _mm256_extractf128_si256(lhs_16_bit_high, 0);
675 __m128i lhs_hi_1 = _mm256_extractf128_si256(lhs_16_bit_high, 1);
676 __m128i rhs_dup_hi_0 = _mm256_extractf128_si256(rhs_dup_hi, 0);
677 __m128i rhs_dup_hi_1 = _mm256_extractf128_si256(rhs_dup_hi, 1);
678 __m128i hi_0 = _mm_madd_epi16(lhs_hi_0, rhs_dup_hi_0);
679 __m128i hi_1 = _mm_madd_epi16(lhs_hi_1, rhs_dup_hi_1);
680
681 accum_lo = _mm_add_epi32(accum_lo, hi_0);
682 accum_hi = _mm_add_epi32(accum_hi, hi_1);
683 accum = _mm256_set_m128i(accum_hi, accum_lo);
684 };
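// Scalar sketch (comments only) of what one process_column call adds to the
// accumulator for destination column c, assuming the packed layouts implied
// by the shuffles above; lhs(r, k) and rhs(k, c) are illustrative names for
// the 8-bit LHS/RHS values at row r, column c and depth k.
//
//   for (int r = 0; r < 8; ++r) {
//     accum[r] += lhs(r, d + 0) * rhs(d + 0, c)    // "low" madd: depths d, d+1
//               + lhs(r, d + 1) * rhs(d + 1, c)
//               + lhs(r, d + 2) * rhs(d + 2, c)    // "high" madd: depths d+2, d+3
//               + lhs(r, d + 3) * rhs(d + 3, c);
//   }
//
// i.e. each depth step folds kAvx8bitInnerSize = 4 levels of the dot product
// into the 32-bit accumulators via two _mm_madd_epi16 pairs.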
685 __m256i tmp0, tmp1, tmp2, tmp3;
686 __m128i lo0, lo1, hi0, hi1;
687 mm256_shuffle_epi32(tmp0, rhs0_3, lo0, hi0, 0);
688 mm256_shuffle_epi32(tmp1, rhs0_3, lo1, hi1, 0x55);
689 process_column(tmp0, tmp1, accum_data_v0);
690 mm256_shuffle_epi32(tmp2, rhs0_3, lo0, hi0, 0xaa);
691 mm256_shuffle_epi32(tmp3, rhs0_3, lo1, hi1, 0xff);
692 process_column(tmp2, tmp3, accum_data_v1);
693
694 mm256_shuffle_epi32(tmp0, rhs4_7, lo0, hi0, 0);
695 mm256_shuffle_epi32(tmp1, rhs4_7, lo1, hi1, 0x55);
696 process_column(tmp0, tmp1, accum_data_v2);
697 mm256_shuffle_epi32(tmp2, rhs4_7, lo0, hi0, 0xaa);
698 mm256_shuffle_epi32(tmp3, rhs4_7, lo1, hi1, 0xff);
699 process_column(tmp2, tmp3, accum_data_v3);
700
701 mm256_shuffle_epi32(tmp0, rhs8_11, lo0, hi0, 0);
702 mm256_shuffle_epi32(tmp1, rhs8_11, lo1, hi1, 0x55);
703 process_column(tmp0, tmp1, accum_data_v4);
704 mm256_shuffle_epi32(tmp2, rhs8_11, lo0, hi0, 0xaa);
705 mm256_shuffle_epi32(tmp3, rhs8_11, lo1, hi1, 0xff);
706 process_column(tmp2, tmp3, accum_data_v5);
707
708 mm256_shuffle_epi32(tmp0, rhs12_15, lo0, hi0, 0);
709 mm256_shuffle_epi32(tmp1, rhs12_15, lo1, hi1, 0x55);
710 process_column(tmp0, tmp1, accum_data_v6);
711 mm256_shuffle_epi32(tmp2, rhs12_15, lo0, hi0, 0xaa);
712 mm256_shuffle_epi32(tmp3, rhs12_15, lo1, hi1, 0xff);
713 process_column(tmp2, tmp3, accum_data_v7);
714
715 lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
716 rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
717 }
718
719 if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
720 __m256i m_vector;
721 __m256i e_vector;
722 // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
723 m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
724 params.multiplier_fixedpoint + multiplier_channel));
725 e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
726 params.multiplier_exponent + multiplier_channel));
727
728 const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>(
729 _mm256_extractf128_si256(m_vector, 0));
730 const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>(
731 _mm256_extractf128_si256(m_vector, 1));
732
733 const __m256i zero_vector = _mm256_setzero_si256();
734 const __m256i left_shift =
735 intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector);
736 const __m256i neg_e_vector =
737 intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
738 const __m256i right_shift =
739 intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
740 const __m256i final_right_shift = _mm256_set1_epi32(31);
741 const __m256i final_right_shift_low =
742 intrin_utils::mm256_cvtepi32_epi64<path>(
743 _mm256_extractf128_si256(final_right_shift, 0));
744 const __m256i final_right_shift_high =
745 intrin_utils::mm256_cvtepi32_epi64<path>(
746 _mm256_extractf128_si256(final_right_shift, 1));
747 const __m256i convert_to_unsigned_64 =
748 _mm256_set1_epi64x(0x8000000000000000);
749
750 __m256i post_scaling_offset = _mm256_setzero_si256();
751
752 // A "half" added for rounding prior to truncation of 64-bit value.
753 const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
754 intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
755 convert_to_unsigned_64);
756
757 if (params.dst_zero_point) {
758 post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
759 }
760
761 // We cannot do
762 //
763 // scaled_v_low =
764 // _mm256_srav_epi64(scaled_v_low, final_right_shift_low);
765 // scaled_v_high =
766 // _mm256_srav_epi64(scaled_v_high, final_right_shift_high);
767 //
768 // since this instruction is not in AVX2. Instead we use
769 // _mm256_srlv_epi64, but this is an unsigned shift, so we applied
770 // offsets before (convert_to_unsigned_64) and after
771 // (convert_to_signed_halved).
772 //
773 // The overall process is, for 64-bit scaled accumulator:
774 // unsigned_accum = signed_accum + 1 << 63;
775 // unsigned_accum = (unsigned_accum >> right_shift) >> 31;
776 // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2;
777
778 // There are various ways to repack the results, in the absence of
779 // _mm256_cvtepi64_epi32() or anything like it.
780 // A.
781 // accum_data_v[j] =
782 // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6),
783 // _mm256_extract_epi32(scaled_v_high, 4),
784 // _mm256_extract_epi32(scaled_v_high, 2),
785 // _mm256_extract_epi32(scaled_v_high, 0),
786 // _mm256_extract_epi32(scaled_v_low, 6),
787 // _mm256_extract_epi32(scaled_v_low, 4),
788 // _mm256_extract_epi32(scaled_v_low, 2),
789 // _mm256_extract_epi32(scaled_v_low, 0));
790 // B.
791 // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8);
792 // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8);
793 // accum_data_v[j] =
794 // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2),
795 // _mm256_extract_epi64(scaled_v_high, 0),
796 // _mm256_extract_epi64(scaled_v_low, 2),
797 // _mm256_extract_epi64(scaled_v_low, 0));
798 // C.
799 // scaled_v_low =
800 // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm);
801 // scaled_v_high =
802 // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm);
803 // accum_data_v[j] =
804 // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20);
805 //
806 // However, we choose the following because it uses two lighter
807 // instructions. The permutation does have a longer latency, but this
808 // loop can be unrolled.
809 // D.
810 // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
811 // __m256i results =
812 // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
813 // results = _mm256_permutevar8x32_epi32(results, repack_perm);
814 // accum_data_v[j] = intrin_utils::mm256_add_epi32<path>(results,
815 // post_scaling_offset);
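// Net effect of the multiplier code below (apply_multiplier plus
// rounding_right_shift) on one accumulator lane, as a scalar sketch (comments
// only, overflow handling aside); `acc`, `m`, `e` are illustrative names for
// the 32-bit accumulator and the per-channel fixedpoint multiplier/exponent.
//
//   std::int64_t prod = std::int64_t(acc << std::max(e, 0)) * m;
//   std::int32_t scaled =
//       static_cast<std::int32_t>((prod + (std::int64_t{1} << 30)) >> 31);
//   scaled = RoundingRightShift(scaled, std::max(-e, 0));  // the lambda below
//   dst_value = scaled + dst_zero_point;  // then clamped and narrowed
//
// The 1 << 63 offset added and later dropped by the vector code only serves
// to emulate the missing arithmetic 64-bit shift with the available unsigned
// one; it cancels out and does not appear in the scalar formulation.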
816
817 // This multiplier code is complex and expensive enough on x86, that
818 // we prefer to implement the channels-are-columns case by transposing
819 // around it, rather than duplicate it (which would also require
820 // duplicating the above code computing the multiplier constants).
821 // This is one instance where channels-are-columns has lower performance
822 // than channels-are-rows.
823 const bool transpose_around_multiplier =
824 (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
825 (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
826 if (transpose_around_multiplier) {
827 // Transpose the 8x8 block of accumulators. It will be un-transposed below,
828 // after the multiplier implementation.
829 intrin_utils::mm256_transpose8x8_epi32<path>(
830 &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
831 &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
832 }
833
834 auto rounding_right_shift = [=](__m256i& results,
835 const __m256i& exponent) {
836 // Construct the "nudge" value for each lane if the exponent is
837 // greater than 0. Otherwise, the nudge is 0.
838 const __m256i zeros = _mm256_setzero_si256();
839 const __m256i mask_rightshift_gtz =
840 intrin_utils::mm256_cmpgt_epi32<path>(exponent, zeros);
841 const __m256i one_shift_exp_minus1 =
842 intrin_utils::mm256_sllv_epi32<path>(
843 _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
844 exponent, _mm256_set1_epi32(1)));
845 __m256i nudge = intrin_utils::mm256_blendv_epi32(
846 zeros, one_shift_exp_minus1, mask_rightshift_gtz);
847 // Calculate the shifted sum (results + nudge) >> exp.
848 const __m256i r_plus_nudge =
849 intrin_utils::mm256_add_epi32<path>(results, nudge);
850 const __m256i shifted_sum =
851 intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, exponent);
852
853 // Identify overflow in each lane and create mask.
854 const __m256i one_shift_31minus_exp =
855 intrin_utils::mm256_sllv_epi32<path>(
856 _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
857 _mm256_set1_epi32(31), exponent));
858 const __m256i mask_num_plus_nudge_overflow =
859 intrin_utils::mm256_cmpgt_epi32<path>(
860 results, intrin_utils::mm256_sub_epi32<path>(
861 _mm256_set1_epi32(0x7fffffff), nudge));
862 // Fill results with either (results + nudge) >> exponent or
863 // 1 << (31 - exp) in the case of overflow.
864 results = intrin_utils::mm256_blendv_epi32(
865 shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
866 };
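// Scalar sketch of the lambda above (comments only; RoundingRightShift and
// its arguments are illustrative names, and exp is always >= 0 here):
//
//   std::int32_t RoundingRightShift(std::int32_t x, int exp) {
//     const std::int32_t nudge = exp > 0 ? (1 << (exp - 1)) : 0;
//     if (x > std::numeric_limits<std::int32_t>::max() - nudge) {
//       return 1 << (31 - exp);  // saturate: (x + nudge) would overflow
//     }
//     return (x + nudge) >> exp;
//   }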
867
868 auto apply_multiplier = [=](__m256i& accum) {
869 __m256i shifted_accum =
870 intrin_utils::mm256_sllv_epi32<path>(accum, left_shift);
871 // Apply the fixed-point part of the multiplier.
872 __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>(
873 intrin_utils::mm256_cvtepi32_epi64<path>(
874 _mm256_extractf128_si256(shifted_accum, 0)),
875 m_64bit_low);
876 __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>(
877 intrin_utils::mm256_cvtepi32_epi64<path>(
878 _mm256_extractf128_si256(shifted_accum, 1)),
879 m_64bit_high);
880 scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
881 offset_vector);
882 scaled_v_high = intrin_utils::mm256_add_epi64<path>(
883 scaled_v_high, offset_vector);
884
885 scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
886 scaled_v_low, final_right_shift_low);
887 scaled_v_high = intrin_utils::mm256_srlv_epi64<path>(
888 scaled_v_high, final_right_shift_high);
889 scaled_v_high =
890 intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32);
891 __m256i results;
892 mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa);
893 // Permute results to this ordering of int32 elements
894 // lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
895 results = intrin_utils::PermuteEpi32EvenOdds<path>(results);
896
897 rounding_right_shift(results, right_shift);
898 accum =
899 intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
900 };
901 apply_multiplier(accum_data_v0);
902 apply_multiplier(accum_data_v1);
903 apply_multiplier(accum_data_v2);
904 apply_multiplier(accum_data_v3);
905 apply_multiplier(accum_data_v4);
906 apply_multiplier(accum_data_v5);
907 apply_multiplier(accum_data_v6);
908 apply_multiplier(accum_data_v7);
909 // See above comment: here we transpose again to undo the transposition
910 // of the 8x8 block of accumulators used to implement the
911 // channels-are-columns case.
912 if (transpose_around_multiplier) {
913 intrin_utils::mm256_transpose8x8_epi32<path>(
914 &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
915 &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
916 }
917 }
918 const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
919 const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
920 const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
921 (residual_cols == kAvx8bitBlockSize);
922
923 __m256i accum_data_v[kAvx8bitBlockSize];
924 if (!store_full_block) {
925 accum_data_v[0] = accum_data_v0;
926 accum_data_v[1] = accum_data_v1;
927 accum_data_v[2] = accum_data_v2;
928 accum_data_v[3] = accum_data_v3;
929 accum_data_v[4] = accum_data_v4;
930 accum_data_v[5] = accum_data_v5;
931 accum_data_v[6] = accum_data_v6;
932 accum_data_v[7] = accum_data_v7;
933 }
934
935 if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
936 std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
937 if (store_full_block) {
938 accum_data_v0 =
939 intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
940 accum_data_v0 =
941 intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
942 accum_data_v1 =
943 intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
944 accum_data_v1 =
945 intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
946 accum_data_v2 =
947 intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
948 accum_data_v2 =
949 intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
950 accum_data_v3 =
951 intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
952 accum_data_v3 =
953 intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
954 accum_data_v4 =
955 intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
956 accum_data_v4 =
957 intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
958 accum_data_v5 =
959 intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
960 accum_data_v5 =
961 intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
962 accum_data_v6 =
963 intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
964 accum_data_v6 =
965 intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
966 accum_data_v7 =
967 intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
968 accum_data_v7 =
969 intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
970 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
971 &tmp_ptr[0 * dst_stride], accum_data_v0);
972 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
973 &tmp_ptr[1 * dst_stride], accum_data_v1);
974 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
975 &tmp_ptr[2 * dst_stride], accum_data_v2);
976 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
977 &tmp_ptr[3 * dst_stride], accum_data_v3);
978 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
979 &tmp_ptr[4 * dst_stride], accum_data_v4);
980 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
981 &tmp_ptr[5 * dst_stride], accum_data_v5);
982 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
983 &tmp_ptr[6 * dst_stride], accum_data_v6);
984 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
985 &tmp_ptr[7 * dst_stride], accum_data_v7);
986 } else {
987 for (int j = 0; j < residual_cols; ++j) {
988 __m256i result = accum_data_v[j];
989 result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
990 result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
991 intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
992 tmp_ptr, residual_rows, result);
993 tmp_ptr += dst_stride;
994 }
995 }
996 dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
997 kAvx8bitBlockSize);
998 } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
999 std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
1000 if (store_full_block) {
1001 accum_data_v0 =
1002 intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
1003 accum_data_v0 =
1004 intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
1005 accum_data_v1 =
1006 intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
1007 accum_data_v1 =
1008 intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
1009 accum_data_v2 =
1010 intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
1011 accum_data_v2 =
1012 intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
1013 accum_data_v3 =
1014 intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
1015 accum_data_v3 =
1016 intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
1017 accum_data_v4 =
1018 intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
1019 accum_data_v4 =
1020 intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
1021 accum_data_v5 =
1022 intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
1023 accum_data_v5 =
1024 intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
1025 accum_data_v6 =
1026 intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
1027 accum_data_v6 =
1028 intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
1029 accum_data_v7 =
1030 intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
1031 accum_data_v7 =
1032 intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
1033 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0],
1034 accum_data_v0);
1035 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride],
1036 accum_data_v1);
1037 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
1038 &tmp_ptr[2 * dst_stride], accum_data_v2);
1039 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
1040 &tmp_ptr[3 * dst_stride], accum_data_v3);
1041 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
1042 &tmp_ptr[4 * dst_stride], accum_data_v4);
1043 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
1044 &tmp_ptr[5 * dst_stride], accum_data_v5);
1045 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
1046 &tmp_ptr[6 * dst_stride], accum_data_v6);
1047 intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
1048 &tmp_ptr[7 * dst_stride], accum_data_v7);
1049 } else {
1050 for (int j = 0; j < residual_cols; ++j) {
1051 __m256i result = accum_data_v[j];
1052 result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
1053 result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
1054 intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
1055 tmp_ptr, residual_rows, result);
1056 tmp_ptr += dst_stride;
1057 }
1058 }
1059 dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
1060 kAvx8bitBlockSize);
1061 } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
1062 std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
1063 if (store_full_block) {
1064 accum_data_v0 =
1065 intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
1066 accum_data_v0 =
1067 intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
1068 accum_data_v1 =
1069 intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
1070 accum_data_v1 =
1071 intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
1072 accum_data_v2 =
1073 intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
1074 accum_data_v2 =
1075 intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
1076 accum_data_v3 =
1077 intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
1078 accum_data_v3 =
1079 intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
1080 accum_data_v4 =
1081 intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
1082 accum_data_v4 =
1083 intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
1084 accum_data_v5 =
1085 intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
1086 accum_data_v5 =
1087 intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
1088 accum_data_v6 =
1089 intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
1090 accum_data_v6 =
1091 intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
1092 accum_data_v7 =
1093 intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
1094 accum_data_v7 =
1095 intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
1096 intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0],
1097 accum_data_v0);
1098 intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride],
1099 accum_data_v1);
1100 intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
1101 &tmp_ptr[2 * dst_stride], accum_data_v2);
1102 intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
1103 &tmp_ptr[3 * dst_stride], accum_data_v3);
1104 intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
1105 &tmp_ptr[4 * dst_stride], accum_data_v4);
1106 intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
1107 &tmp_ptr[5 * dst_stride], accum_data_v5);
1108 intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
1109 &tmp_ptr[6 * dst_stride], accum_data_v6);
1110 intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
1111 &tmp_ptr[7 * dst_stride], accum_data_v7);
1112 } else {
1113 for (int j = 0; j < residual_cols; ++j) {
1114 __m256i result = accum_data_v[j];
1115 result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
1116 result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
1117 intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(
1118 tmp_ptr, residual_rows, result);
1119 tmp_ptr += dst_stride;
1120 }
1121 }
1122 dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
1123 kAvx8bitBlockSize);
1124 } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
1125 if (store_full_block) {
1126 std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
1127 intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0);
1128 intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride],
1129 accum_data_v1);
1130 intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride],
1131 accum_data_v2);
1132 intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride],
1133 accum_data_v3);
1134 intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride],
1135 accum_data_v4);
1136 intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride],
1137 accum_data_v5);
1138 intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride],
1139 accum_data_v6);
1140 intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride],
1141 accum_data_v7);
1142 } else {
1143 std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
1144 for (int j = 0; j < residual_cols; ++j) {
1145 intrin_utils::mm256_n_storeu_epi32<path>(
1146 dst_block_ptr, residual_rows, accum_data_v[j]);
1147 dst_block_ptr += dst_stride;
1148 }
1149 }
1150 dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
1151 kAvx8bitBlockSize);
1152 } else {
1153 RUY_DCHECK(false);
1154 }
1155
1156 lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
1157 } // End row-block loop.
1158
1159 dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
1160 kAvx8bitBlockSize * params.dst_stride);
1161 rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
1162 } // End col-block loop.
1163 } // NOLINT(readability/fn_size)
1164
1165 void Kernel8bitAvx(const KernelParams8bit<8, 8>& params) {
1166 Kernel8bitAvxImpl<Path::kAvx>(params);
1167 }
1168
1169 template <Path path>
1170 void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) {
1171 profiler::ScopeLabel label("Kernel kAvx 8-bit GEMV");
1172
1173 RUY_DCHECK_EQ(params.dst_cols, 1);
1174 RUY_DCHECK_EQ(params.last_col, 0);
1175 RUY_DCHECK_EQ(params.start_col, 0);
1176
1177 const std::int8_t splitter_idx_data[32] = {
1178 0, 1, 4, 5, 8, 9, 12, 13, //
1179 2, 3, 6, 7, 10, 11, 14, 15, //
1180 0, 1, 4, 5, 8, 9, 12, 13, //
1181 2, 3, 6, 7, 10, 11, 14, 15 //
1182 };
1183
1184 int bias_ptr_block_increment =
1185 params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
1186
1187 const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
1188 void* dst_col_ptr = params.dst_base_ptr;
1189 const std::int32_t* bias_col_ptr = params.bias;
1190 if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
1191 bias_col_ptr += params.start_row;
1192 }
1193
1194 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
1195 void* dst_ptr = dst_col_ptr;
1196 const std::int32_t* bias_ptr = bias_col_ptr;
1197
1198 const std::int32_t lhs_zero_point = params.lhs_zero_point;
1199 const bool has_rhs_sums_offsets =
1200 (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
1201 std::int32_t rhs_sums_offsets[8];
1202 if (has_rhs_sums_offsets) {
1203 const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>(
1204 _mm256_set1_epi32(lhs_zero_point),
1205 _mm256_loadu_si256(
1206 reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
1207 _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
1208 rhs_sums_offset_v);
1209 }
1210
1211 for (int row = params.start_row; row <= params.last_row;
1212 row += kAvx8bitBlockSize) {
1213 const int residual_rows =
1214 std::min(params.dst_rows - row, kAvx8bitBlockSize);
1215
1216 const __m256i splitter_idx =
1217 _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data));
1218
1219 __m256i accum_data_v0;
1220
1221 // Initialize with bias.
1222 __m256i initial_accum_data =
1223 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr));
1224 bias_ptr += bias_ptr_block_increment;
1225
1226 // Adjustments common across columns.
1227 const std::int32_t rhs_zero_point = params.rhs_zero_point;
1228 if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
1229 const __m256i lhs_sums_offset = intrin_utils::mm256_mullo_epi32<path>(
1230 _mm256_set1_epi32(rhs_zero_point),
1231 _mm256_loadu_si256(
1232 reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
1233 initial_accum_data = intrin_utils::mm256_sub_epi32<path>(
1234 initial_accum_data, lhs_sums_offset);
1235 }
1236 const std::int32_t prod_zp_depth = params.prod_zp_depth;
1237 if (prod_zp_depth) {
1238 initial_accum_data = intrin_utils::mm256_add_epi32<path>(
1239 initial_accum_data, _mm256_set1_epi32(prod_zp_depth));
1240 }
1241
1242 // Adjustments differing across columns.
1243 if (has_rhs_sums_offsets) {
1244 accum_data_v0 = intrin_utils::mm256_sub_epi32<path>(
1245 initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
1246 } else {
1247 accum_data_v0 = initial_accum_data;
1248 }
1249
1250 const std::int8_t* lhs_ptr = lhs_col_ptr;
1251 const std::int8_t* rhs_ptr = rhs_col_ptr;
1252 for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
1253 const __m256i lhs_data =
1254 _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
1255 const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr);
1256
1257 // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
1258 // For simplicity we load 4x the data that we need, process twice the
1259 // data that we need, and store only the data we need.
1260 std::int32_t rhs_data[2];
1261 const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
1262 // Now that we have cast the RHS data, we store it so that each value
1263 // can be separately loaded in the accumulation loop.
1264 _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
1265
1266 // NOTE: There may be opportunities for permuting the data in the packing
1267 // code instead of here.
1268 const __m256i lhs_data_split =
1269 intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx);
1270 const __m256i lhs_data_split_expand_bottom =
1271 intrin_utils::mm256_cvtepi8_epi16<path>(
1272 _mm256_extractf128_si256(lhs_data_split, 0));
1273 const __m256i lhs_data_split_expand_top =
1274 intrin_utils::mm256_cvtepi8_epi16<path>(
1275 _mm256_extractf128_si256(lhs_data_split, 1));
1276
1277 // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
1278 const __m256i lhs_16_bit_low =
1279 intrin_utils::mm256_permute2x128_si256<path>(
1280 lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
1281 // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
1282 const __m256i lhs_16_bit_high =
1283 intrin_utils::mm256_permute2x128_si256<path>(
1284 lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
1285 // Accumulate for column 0.
1286 const std::int32_t low_rhs_value = rhs_data[0];
1287 const std::int32_t high_rhs_value = rhs_data[1];
1288
1289 const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
1290 const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
1291
1292 accum_data_v0 = intrin_utils::mm256_add_epi32<path>(
1293 accum_data_v0, intrin_utils::mm256_madd_epi16<path>(
1294 lhs_16_bit_low, rhs_16_bit_dup_low));
1295 accum_data_v0 = intrin_utils::mm256_add_epi32<path>(
1296 accum_data_v0, intrin_utils::mm256_madd_epi16<path>(
1297 lhs_16_bit_high, rhs_16_bit_dup_high));
1298
1299 lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
1300 rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
1301 }
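// Scalar sketch (comments only) of what one depth step of the loop above adds
// to lane r of accum_data_v0; lhs(r, k) and rhs(k) are illustrative names for
// the 8-bit LHS/RHS values at row r and depth k of this block.
//
//   accum[r] += lhs(r, d + 0) * rhs(d + 0) + lhs(r, d + 1) * rhs(d + 1)
//             + lhs(r, d + 2) * rhs(d + 2) + lhs(r, d + 3) * rhs(d + 3);
//
// This is the single-column (GEMV) specialization of the process_column
// arithmetic in the GEMM kernel above.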
1302
1303 if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
1304 __m256i m_vector;
1305 __m256i e_vector;
1306 // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
1307 int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
1308 m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
1309 params.multiplier_fixedpoint + channel));
1310 e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
1311 params.multiplier_exponent + channel));
1312
1313 const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>(
1314 _mm256_extractf128_si256(m_vector, 0));
1315 const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>(
1316 _mm256_extractf128_si256(m_vector, 1));
1317
1318 const __m256i zero_vector = _mm256_setzero_si256();
1319 const __m256i left_shift =
1320 intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector);
1321 const __m256i neg_e_vector =
1322 intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
1323 const __m256i right_shift =
1324 intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
1325 const __m256i final_right_shift = _mm256_set1_epi32(31);
1326 const __m256i final_right_shift_low =
1327 intrin_utils::mm256_cvtepi32_epi64<path>(
1328 _mm256_extractf128_si256(final_right_shift, 0));
1329 const __m256i final_right_shift_high =
1330 intrin_utils::mm256_cvtepi32_epi64<path>(
1331 _mm256_extractf128_si256(final_right_shift, 1));
1332 const __m256i convert_to_unsigned_64 =
1333 _mm256_set1_epi64x(0x8000000000000000);
1334
1335 __m256i post_scaling_offset = _mm256_setzero_si256();
1336
1337 // A "half" added for rounding prior to truncation of 64-bit value.
1338 const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
1339 intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
1340 convert_to_unsigned_64);
1341
1342 if (params.dst_zero_point) {
1343 post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
1344 }
1345
1346 // See GEMM version for details of this process.
1347 {
1348 __m256i shifted_accum =
1349 intrin_utils::mm256_sllv_epi32<path>(accum_data_v0, left_shift);
1350 // Apply the fixed-point part of the multiplier.
1351 __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>(
1352 intrin_utils::mm256_cvtepi32_epi64<path>(
1353 _mm256_extractf128_si256(shifted_accum, 0)),
1354 m_64bit_low);
1355 __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>(
1356 intrin_utils::mm256_cvtepi32_epi64<path>(
1357 _mm256_extractf128_si256(shifted_accum, 1)),
1358 m_64bit_high);
1359
1360 scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
1361 offset_vector);
1362 scaled_v_high = intrin_utils::mm256_add_epi64<path>(scaled_v_high,
1363 offset_vector);
1364
1365 scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
1366 scaled_v_low, final_right_shift_low);
1367 scaled_v_high = intrin_utils::mm256_srlv_epi64<path>(
1368 scaled_v_high, final_right_shift_high);
1369
1370 scaled_v_high = intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32);
1371 __m256i results;
1372 mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa);
1373 // Permute results to this ordering of int32 elements
1374 // lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
1375 results = intrin_utils::PermuteEpi32EvenOdds<path>(results);
1376
1377 // Now perform the Rounding Right Shift.
1378 // First, construct the "nudge" value for each lane if the exponent is
1379 // greater than 0. Otherwise, the nudge is 0.
1380 const __m256i zeros = _mm256_setzero_si256();
1381 const __m256i mask_rightshift_gtz =
1382 intrin_utils::mm256_cmpgt_epi32<path>(right_shift, zeros);
1383 const __m256i one_shift_exp_minus1 =
1384 intrin_utils::mm256_sllv_epi32<path>(
1385 _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
1386 right_shift, _mm256_set1_epi32(1)));
1387 __m256i nudge = intrin_utils::mm256_blendv_epi32(
1388 zeros, one_shift_exp_minus1, mask_rightshift_gtz);
1389 // Calculate the shifted sum (results + nudge) >> exp.
1390 const __m256i r_plus_nudge =
1391 intrin_utils::mm256_add_epi32<path>(results, nudge);
1392 const __m256i shifted_sum =
1393 intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, right_shift);
1394
1395 // Identify overflow in each lane and create mask.
1396 const __m256i one_shift_31minus_exp =
1397 intrin_utils::mm256_sllv_epi32<path>(
1398 _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
1399 _mm256_set1_epi32(31), right_shift));
1400 const __m256i mask_num_plus_nudge_overflow =
1401 intrin_utils::mm256_cmpgt_epi32<path>(
1402 results, intrin_utils::mm256_sub_epi32<path>(
1403 _mm256_set1_epi32(0x7fffffff), nudge));
1404 // Fill results with either (results + nudge) >> exponent or
1405 // 1 << (31 - exp) in the case of overflow.
1406 results = intrin_utils::mm256_blendv_epi32(
1407 shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
1408 accum_data_v0 =
1409 intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
1410 }
1411 }
1412 const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
1413 const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
1414
1415 if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
1416 std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
1417 __m256i result = accum_data_v0;
1418 result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
1419 result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
1420 intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
1421 result);
1422 dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
1423 kAvx8bitBlockSize);
1424 } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
1425 std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
1426 __m256i result = accum_data_v0;
1427 result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
1428 result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
1429 intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
1430 result);
1431 dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
1432 kAvx8bitBlockSize);
1433 } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
1434 std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
1435 __m256i result = accum_data_v0;
1436 result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
1437 result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
1438 intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows,
1439 result);
1440 dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
1441 kAvx8bitBlockSize);
1442 } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
1443 std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
1444 intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows,
1445 accum_data_v0);
1446 dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
1447 kAvx8bitBlockSize);
1448 } else {
1449 RUY_DCHECK(false);
1450 }
1451
1452 lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
1453 } // End row-block loop.
1454
1455 dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
1456 kAvx8bitBlockSize * params.dst_stride);
1457 rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
1458 } // NOLINT(readability/fn_size)
1459
1460 void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params) {
1461 Kernel8bitAvxSingleColImpl<Path::kAvx>(params);
1462 }
1463
1464 void KernelFloatAvx(const KernelParamsFloat<8, 8>& params) {
1465 profiler::ScopeLabel label("Kernel kAvx float");
1466 KernelFloatAvxCommon<Path::kAvx>(params);
1467 }
1468
1469 void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params) {
1470 profiler::ScopeLabel label("Kernel kAvx float GEMV");
1471 KernelFloatAvxCommonSingleCol<Path::kAvx>(params);
1472 }
1473
1474 #endif // RUY_PLATFORM_AVX && RUY_OPT(ASM)
1475
1476 } // namespace ruy
1477