/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/opt_set.h"
#include "ruy/pack_x86.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM))

void Pack8bitColMajorForAvx512(const std::int8_t*, std::int8_t,
                               const std::int8_t*, int, int, int, std::int8_t*,
                               std::int32_t*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void PackFloatColMajorForAvx512(const float*, const float*, int, int, int,
                                float*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Pack8bitRowMajorForAvx512(const std::uint8_t*, int, int, std::int8_t*, int,
                               int, int, int, int, int, int, std::int32_t*) {
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)

// The first int8_t template parameter is arbitrary: this routine is common to
// all 8-bit source matrix types.
using PackImpl8bitAvx512 =
    PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
             std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>;

namespace {

inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point,
                               std::int8_t* packed_ptr) {
  using Layout = PackImpl8bitAvx512::Layout;
  static constexpr int kHalfLayoutCols =
      PackImpl8bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
                                            // block.
  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
  RUY_DCHECK_EQ(Layout::kCols, 16);
  RUY_DCHECK_EQ(Layout::kRows, 4);

  const int non_trailing_blocks = (src_rows & ~31) >> 2;
  // This routine fills half blocks, and typically fills the second halves.
  // Thus packed_ptr is already offset by 8 * 4.
  for (int k = 0; k < non_trailing_blocks; ++k) {
    for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) {
      packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point;
    }
  }
}

inline __m512i LoaduTwo(const std::int8_t* addr_lo,
                        const std::int8_t* addr_hi) {
  __m512i lower_filled = _mm512_castsi256_si512(
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_lo)));
  return _mm512_inserti32x8(
      lower_filled,
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_hi)), 1);
}

inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
                            const std::int8_t* addr_lo,
                            const std::int8_t* addr_hi) {
  const __m512i lower_filled = _mm512_castsi256_si512(
      _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo));
  return _mm512_inserti32x8(
      lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi),
      1);
}

inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
                               std::int8_t input_xor,
                               const std::int8_t* zerobuf, int src_stride,
                               int remaining_src_cols, int src_rows,
                               std::int8_t* packed_ptr, std::int32_t* sums_ptr,
                               std::int8_t* trailing_buf) {
  using Layout = PackImpl8bitAvx512::Layout;
  RUY_DCHECK_EQ(Layout::kCols, 16);
  RUY_DCHECK_EQ(Layout::kRows, 4);
  // Each Layout::kRows chunk is 4 elements that are contiguous in the source
  // column and stay contiguous in the packed output. We process 8 of these
  // chunks at a time, padding short input chunks.
  constexpr int kNumRowChunks = 8;
  constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;

  const std::int8_t* src_ptr0 = src_ptr;
  const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
  const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
  const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
  const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
  const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
  const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
  const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = kNumChunkedSrcRows;
  std::int64_t src_inc1 = kNumChunkedSrcRows;
  std::int64_t src_inc2 = kNumChunkedSrcRows;
  std::int64_t src_inc3 = kNumChunkedSrcRows;
  std::int64_t src_inc4 = kNumChunkedSrcRows;
  std::int64_t src_inc5 = kNumChunkedSrcRows;
  std::int64_t src_inc6 = kNumChunkedSrcRows;
  std::int64_t src_inc7 = kNumChunkedSrcRows;
  // Handle cases where source does not have kHalfLayoutCols (8) columns.
  if (remaining_src_cols < 8) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  const std::int8_t zero_point = zerobuf[0];

  if (sums_ptr) {
    // i: kHalfLayoutCols.
    for (int i = 0; i < 8; ++i) {
      sums_ptr[i] = 0;
    }
  }
  std::int32_t sums_adjustment = 0;
  const __m512i ones_16bit = _mm512_set1_epi16(1);
  __m512i sums_8x2_32bit = _mm512_set1_epi32(0);

  // The overall packing effectively pads the source rows to
  // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
  // only pack for (src_rows + 31) & ~31. When there is an incomplete
  // destination block, this is stored into trailing_buf instead of packed_ptr.
  for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
    // m: {0, 1} for 2 chunks of rows.
    for (int m = 0; m < 2; ++m) {
      // Available source rows.
      // If this is less than 0 (for m=1), we skip, having filled trailing
      // buffer for m=0. Also, if source rows is zero on m=1, then we filled
      // exactly to the end of the column in the packed buffer.
      const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
      // Effectively this clamps available_src_rows to the range
      // [0, kNumChunkedSrcRows], but each case is treated separately.
      if (available_src_rows >= kNumChunkedSrcRows) {
        // i: chunks, s: Layout::Rows.
        if (sums_ptr) {
          __m512i t0, t1, t2, t3;
          __m512i r0, r1, r2, r3;
          const __m512i input_xor_v = _mm512_set1_epi8(input_xor);

          t0 = LoaduTwo(src_ptr0, src_ptr4);
          t1 = LoaduTwo(src_ptr1, src_ptr5);
          t2 = LoaduTwo(src_ptr2, src_ptr6);
          t3 = LoaduTwo(src_ptr3, src_ptr7);

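          // Each 256-bit half of t0..t3 now holds eight 4-byte row chunks
          // from a single source column (low halves: columns 0-3, high
          // halves: columns 4-7). The unpack/shuffle sequence below
          // transposes this so that each 256-bit half of r0..r3 holds the
          // same 4-row chunk from all 8 columns, which is the packed layout.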
          r0 = _mm512_unpacklo_epi32(t0, t1);
          r2 = _mm512_unpackhi_epi32(t0, t1);
          r1 = _mm512_unpacklo_epi32(t2, t3);
          r3 = _mm512_unpackhi_epi32(t2, t3);

          t0 = _mm512_unpacklo_epi64(r0, r1);
          t2 = _mm512_unpackhi_epi64(r0, r1);
          t1 = _mm512_unpacklo_epi64(r2, r3);
          t3 = _mm512_unpackhi_epi64(r2, r3);

          r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
          r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
          r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
          r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);

          r0 = _mm512_xor_si512(r0, input_xor_v);
          r1 = _mm512_xor_si512(r1, input_xor_v);
          r2 = _mm512_xor_si512(r2, input_xor_v);
          r3 = _mm512_xor_si512(r3, input_xor_v);

          const __m256i r0_0 = _mm512_castsi512_si256(r0);
          const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
          const __m256i r1_0 = _mm512_castsi512_si256(r1);
          const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
          const __m256i r2_0 = _mm512_castsi512_si256(r2);
          const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
          const __m256i r3_0 = _mm512_castsi512_si256(r3);
          const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);

          __m512i sums_8x4_16bit;
          sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
          // The sums have been performed across columns, and now we have
          // 4x16-bit sums packed together. We use madd for pairwise 32-bit
          // sums.
          const __m512i sums_8x2_32bit_new =
              _mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
          sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);

          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1);
        } else {
          __m512i t0, t1, t2, t3;
          __m512i r0, r1, r2, r3;
          const __m512i input_xor_v = _mm512_set1_epi8(input_xor);

          t0 = LoaduTwo(src_ptr0, src_ptr4);
          t1 = LoaduTwo(src_ptr1, src_ptr5);
          t2 = LoaduTwo(src_ptr2, src_ptr6);
          t3 = LoaduTwo(src_ptr3, src_ptr7);

          r0 = _mm512_unpacklo_epi32(t0, t1);
          r2 = _mm512_unpackhi_epi32(t0, t1);
          r1 = _mm512_unpacklo_epi32(t2, t3);
          r3 = _mm512_unpackhi_epi32(t2, t3);

          t0 = _mm512_unpacklo_epi64(r0, r1);
          t2 = _mm512_unpackhi_epi64(r0, r1);
          t1 = _mm512_unpacklo_epi64(r2, r3);
          t3 = _mm512_unpackhi_epi64(r2, r3);

          r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
          r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
          r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
          r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);

          r0 = _mm512_xor_si512(r0, input_xor_v);
          r1 = _mm512_xor_si512(r1, input_xor_v);
          r2 = _mm512_xor_si512(r2, input_xor_v);
          r3 = _mm512_xor_si512(r3, input_xor_v);

          const __m256i r0_0 = _mm512_castsi512_si256(r0);
          const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
          const __m256i r1_0 = _mm512_castsi512_si256(r1);
          const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
          const __m256i r2_0 = _mm512_castsi512_si256(r2);
          const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
          const __m256i r3_0 = _mm512_castsi512_si256(r3);
          const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1);
        }
      } else if (available_src_rows > 0) {
        RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
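        // One mask bit per byte: the low available_src_rows bits select the
        // valid bytes of each 32-byte column load below; masked-off bytes are
        // filled with the zero point.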
        const __mmask32 row_mask =
            (static_cast<std::uint64_t>(1) << available_src_rows) - 1;

        // We do not care what goes into the trailing buffer, but we want
        // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
        //
        // We compensate for padding-with-zero_point by initializing the
        // summations with the compensating offset, effectively
        // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
        // 4 * (8 - ((available_src_rows + 3) >> 2)).
        //
        // Note that (zero_point ^ input_xor) is performed in 8-bits and then
        // cast.
        sums_adjustment += -(zero_point ^ input_xor) * 4 *
                           (8 - ((available_src_rows + 3) >> 2));

        __m512i t0, t1, t2, t3;
        __m512i r0, r1, r2, r3;
        const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
        const __m256i zero_point_v = _mm256_set1_epi8(zero_point);

        t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4);
        t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5);
        t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6);
        t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7);

        r0 = _mm512_unpacklo_epi32(t0, t1);
        r2 = _mm512_unpackhi_epi32(t0, t1);
        r1 = _mm512_unpacklo_epi32(t2, t3);
        r3 = _mm512_unpackhi_epi32(t2, t3);

        t0 = _mm512_unpacklo_epi64(r0, r1);
        t2 = _mm512_unpackhi_epi64(r0, r1);
        t1 = _mm512_unpacklo_epi64(r2, r3);
        t3 = _mm512_unpackhi_epi64(r2, r3);

        r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
        r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
        r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
        r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);

        r0 = _mm512_xor_si512(r0, input_xor_v);
        r1 = _mm512_xor_si512(r1, input_xor_v);
        r2 = _mm512_xor_si512(r2, input_xor_v);
        r3 = _mm512_xor_si512(r3, input_xor_v);

        const __m256i r0_0 = _mm512_castsi512_si256(r0);
        const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
        const __m256i r1_0 = _mm512_castsi512_si256(r1);
        const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
        const __m256i r2_0 = _mm512_castsi512_si256(r2);
        const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
        const __m256i r3_0 = _mm512_castsi512_si256(r3);
        const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);

        __m512i sums_8x4_16bit;
        sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
        // The sums have been performed across columns, and now we have
        // 4x16-bit sums packed together. We use madd for pairwise 32-bit
        // sums.
        const __m512i sums_8x2_32bit_new =
            _mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
        sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);

        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 0 * 16 * 4), r0_0);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 2 * 16 * 4), r0_1);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 4 * 16 * 4), r1_0);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 6 * 16 * 4), r1_1);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 1 * 16 * 4), r2_0);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 3 * 16 * 4), r2_1);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 5 * 16 * 4), r3_0);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 7 * 16 * 4), r3_1);
      }

      packed_ptr += 16 * kNumChunkedSrcRows;
      src_ptr0 += src_inc0;
      src_ptr1 += src_inc1;
      src_ptr2 += src_inc2;
      src_ptr3 += src_inc3;
      src_ptr4 += src_inc4;
      src_ptr5 += src_inc5;
      src_ptr6 += src_inc6;
      src_ptr7 += src_inc7;
    }
  }

  if (sums_ptr) {
    const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);

    __m256i sums =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
    const __m512i idx =
        _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);

    // We earlier used madd for pairwise 32-bit sums, and now we deinterlace
    // the neighbours, finishing up by adding them to the stored accumulated
    // sums.
    const __m512i sums_2x8_32bit =
        _mm512_permutexvar_epi32(idx, sums_8x2_32bit);
    sums = _mm256_add_epi32(sums, sums_adjustment_v);
    sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit));
    sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1));

    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
  }
}

inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) {
  const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo));
  return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1);
}

inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo,
                           const float* addr_hi) {
  const __m512 lower_filled =
      _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo));
  return _mm512_insertf32x8(lower_filled,
                            _mm256_maskz_loadu_ps(row_mask, addr_hi), 1);
}

inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) {
  return _mm512_castpd_ps(
      _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
}

inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) {
  return _mm512_castpd_ps(
      _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
}

inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf,
                                int src_stride, int remaining_src_cols,
                                int src_rows, float* packed_ptr,
                                float* trailing_buf) {
  const float* src_ptr0 = src_ptr;
  const float* src_ptr1 = src_ptr0 + src_stride;
  const float* src_ptr2 = src_ptr1 + src_stride;
  const float* src_ptr3 = src_ptr2 + src_stride;
  const float* src_ptr4 = src_ptr3 + src_stride;
  const float* src_ptr5 = src_ptr4 + src_stride;
  const float* src_ptr6 = src_ptr5 + src_stride;
  const float* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = 8;
  std::int64_t src_inc1 = 8;
  std::int64_t src_inc2 = 8;
  std::int64_t src_inc3 = 8;
  std::int64_t src_inc4 = 8;
  std::int64_t src_inc5 = 8;
  std::int64_t src_inc6 = 8;
  std::int64_t src_inc7 = 8;
  if (remaining_src_cols < 8) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  for (int k = 0; k < src_rows; k += 16) {
    for (int m = 0; m < 2; ++m) {
      const int available_src_rows = src_rows - k - 8 * m;
      // Effectively,
      // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m));
      // but treat each case separately.
      if (available_src_rows > 7) {
        __m512 t0, t1, t2, t3;
        __m512 r0, r1, r2, r3;

        t0 = LoaduTwo(src_ptr0, src_ptr4);
        t1 = LoaduTwo(src_ptr1, src_ptr5);
        t2 = LoaduTwo(src_ptr2, src_ptr6);
        t3 = LoaduTwo(src_ptr3, src_ptr7);

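        // Transpose the 8x8 block of floats: each 256-bit half of t0..t3
        // holds 8 rows of a single source column; after the unpack/shuffle
        // sequence below, each 256-bit half of r0..r3 holds one row from each
        // of the 8 columns, which is the packed layout.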
        r0 = _mm512_unpacklo_ps(t0, t1);
        r2 = _mm512_unpackhi_ps(t0, t1);
        r1 = _mm512_unpacklo_ps(t2, t3);
        r3 = _mm512_unpackhi_ps(t2, t3);

        t0 = Mm512UnpackloPsx2(r0, r1);
        t2 = Mm512UnpackhiPsx2(r0, r1);
        t1 = Mm512UnpackloPsx2(r2, r3);
        t3 = Mm512UnpackhiPsx2(r2, r3);

        r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
        r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
        r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
        r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);

        _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0));
        _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
        _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1));
        _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
        _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2));
        _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
        _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3));
        _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1));
      } else if (available_src_rows > 0) {
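        // One mask bit per float lane: the low available_src_rows bits select
        // the valid rows; masked-off lanes are loaded as zero.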
        const __mmask8 row_mask =
            (static_cast<std::uint32_t>(1) << available_src_rows) - 1;

        __m512 t0, t1, t2, t3;
        __m512 r0, r1, r2, r3;

        t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4);
        t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5);
        t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6);
        t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7);

        r0 = _mm512_unpacklo_ps(t0, t1);
        r2 = _mm512_unpackhi_ps(t0, t1);
        r1 = _mm512_unpacklo_ps(t2, t3);
        r3 = _mm512_unpackhi_ps(t2, t3);

        t0 = Mm512UnpackloPsx2(r0, r1);
        t2 = Mm512UnpackhiPsx2(r0, r1);
        t1 = Mm512UnpackloPsx2(r2, r3);
        t3 = Mm512UnpackhiPsx2(r2, r3);

        r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
        r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
        r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
        r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);

        _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0));
        _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
        _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1));
        _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
        _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2));
        _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
        _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3));
        // Do not store _mm512_extractf32x8_ps(r3, 1).
      }

      packed_ptr += 16 * 8;
      src_ptr0 += src_inc0;
      src_ptr1 += src_inc1;
      src_ptr2 += src_inc2;
      src_ptr3 += src_inc3;
      src_ptr4 += src_inc4;
      src_ptr5 += src_inc5;
      src_ptr6 += src_inc6;
      src_ptr7 += src_inc7;
    }
  }
}

inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) {
  const int non_trailing_rows = src_rows & ~7;
  for (int k = 0; k < non_trailing_rows; ++k) {
    for (int j = 0; j < 8; ++j) {
      packed_ptr[j] = 0.0f;
    }
    packed_ptr += 16;
  }
}

}  // namespace.

void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
                               std::int8_t input_xor,
                               const std::int8_t* zerobuf, int src_stride,
                               int remaining_src_cols, int src_rows,
                               std::int8_t* packed_ptr,
                               std::int32_t* sums_ptr) {
  profiler::ScopeLabel label("Pack kAvx512 8bit");

  using Layout = PackImpl8bitAvx512::Layout;
  constexpr int kHalfBlockOffset = 32;
  RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
  static constexpr int kHalfLayoutCols =
      PackImpl8bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
                                            // block.
  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
  RUY_DCHECK_EQ(Layout::kCols, 16);
  RUY_DCHECK_EQ(Layout::kRows, 4);

  // Each Layout::kRows chunk is 4 elements that are contiguous in the source
  // column and stay contiguous in the packed output. We process 8 of these
  // chunks at a time, padding short input chunks.
  constexpr int kNumRowChunks = 8;

  // Each packed block is 4*16, and there are normally 8. The trailing block is
  // only slightly shorter.
  constexpr int kTrailingBufSize =
      kNumRowChunks * Layout::kCols * Layout::kRows;
  std::int8_t trailing_buf[kTrailingBufSize];
  memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));

  std::int32_t* second_sums_ptr =
      sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
  if (remaining_src_cols > kHalfLayoutCols) {
    HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
                       remaining_src_cols, src_rows, packed_ptr, sums_ptr,
                       trailing_buf);
    HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor,
                       zerobuf, src_stride,
                       remaining_src_cols - kHalfLayoutCols, src_rows,
                       packed_ptr + kHalfBlockOffset, second_sums_ptr,
                       trailing_buf + kHalfBlockOffset);
  } else {
    HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
                       remaining_src_cols, src_rows, packed_ptr, sums_ptr,
                       trailing_buf);
    ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor,
                       packed_ptr + kHalfBlockOffset);
    // The kernel may not need the second half-block's sums to be set.
    if (second_sums_ptr) {
      for (int i = 0; i < kHalfLayoutCols; ++i) {
        second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3);
      }
    }
  }
  constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
  const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
  // If the number of source rows is not a multiple of kChunkedRowMask + 1,
  // there will be data in the trailing buffer.
  if (trailing_data) {
    const int non_trailing_rows = src_rows & ~kChunkedRowMask;
    // Destination "rows" are padded to next highest multiple of Layout::kRows.
    const int dst_rows = (src_rows + 3) & ~3;
    const int trailing_rows = dst_rows - non_trailing_rows;
    memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
           Layout::kCols * trailing_rows * sizeof(std::int8_t));
  }
}

void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
                                int src_stride, int remaining_src_cols,
                                int src_rows, float* packed_ptr) {
  profiler::ScopeLabel label("Pack kAvx512 float");
  float trailing_buf[7 * 16];
  if (remaining_src_cols > 8) {
    HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
                        src_rows, packed_ptr, trailing_buf);
    HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride,
                        remaining_src_cols - 8, src_rows, packed_ptr + 8,
                        trailing_buf + 8);
  } else {
    memset(trailing_buf, 0, sizeof(trailing_buf));
    HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
                        src_rows, packed_ptr, trailing_buf);
    ZeroHalfFloatAvx512(src_rows, packed_ptr + 8);
  }
  const int trailing_rows = src_rows & 7;
  if (trailing_rows > 0) {
    const int non_trailing_rows = src_rows & ~7;
    memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
           16 * trailing_rows * sizeof(float));
  }
}

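// Packs 4 rows of the row-major 8-bit source, starting at block_row, for
// columns [start_col, end_col): values are XOR-ed with input_xor (converting
// uint8 to int8 when input_xor is 0x80), rows past src_rows are padded with
// src_zero_point, columns past src_cols are zero-filled, and per-column sums
// are accumulated into sums.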
void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
                               int src_zero_point, std::int8_t* packed_ptr,
                               int packed_stride, int start_col, int end_col,
                               int src_cols, int block_row, int src_rows,
                               int input_xor, std::int32_t* sums) {
  int col = start_col;
  int src_end_col = std::min(end_col, src_cols);

  for (; col <= src_end_col - 16; col += 16) {
    std::int8_t* dst_ptr = packed_ptr;
    __m128i val0, val1, val2, val3;
    __m128i input_xor_dup = _mm_set1_epi8(input_xor);
    // Load a 4x16 block.
    if (block_row + 4 <= src_rows) {
      val0 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
      val1 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
      val2 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
      val3 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
    } else {
      val0 = _mm_set1_epi8(src_zero_point);
      val1 = val0;
      val2 = val0;
      val3 = val0;
      if (block_row + 0 < src_rows)
        val0 = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
      if (block_row + 1 < src_rows)
        val1 = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
      if (block_row + 2 < src_rows)
        val2 = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
      if (block_row + 3 < src_rows)
        val3 = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
    }
    // Maybe xor the sign bit to convert from uint8 to int8.
    val0 = _mm_xor_si128(val0, input_xor_dup);
    val1 = _mm_xor_si128(val1, input_xor_dup);
    val2 = _mm_xor_si128(val2, input_xor_dup);
    val3 = _mm_xor_si128(val3, input_xor_dup);
    // Update the sums.
    __m256i val16_0 = _mm256_cvtepi8_epi16(val0);
    __m256i val16_1 = _mm256_cvtepi8_epi16(val1);
    __m256i val16_2 = _mm256_cvtepi8_epi16(val2);
    __m256i val16_3 = _mm256_cvtepi8_epi16(val3);
    __m256i new_sum16 = _mm256_add_epi16(_mm256_add_epi16(val16_0, val16_1),
                                         _mm256_add_epi16(val16_2, val16_3));
    __m512i sum =
        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(sums + col));
    sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(new_sum16));
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(sums + col), sum);
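    // zip(x, y) returns the full byte interleave [x0, y0, x1, y1, ...]:
    // perm_64_0_64_0 spreads the two 64-bit halves of its 128-bit argument
    // into the even 64-bit lanes of a 256-bit register, so a single
    // unpacklo_epi8 interleaves all 16 bytes of x with all 16 bytes of y.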
    auto zip = [](__m128i x, __m128i y) {
      auto perm_64_0_64_0 = [](__m128i x) {
        return _mm256_permutexvar_epi64(_mm256_setr_epi64x(0, 2, 1, 3),
                                        _mm256_castsi128_si256(x));
      };
      return _mm256_unpacklo_epi8(perm_64_0_64_0(x), perm_64_0_64_0(y));
    };
    __m256i t2_val0 = zip(val0, val1);
    __m256i t2_val1 = zip(val2, val3);
    __m256i t4_val0 = _mm256_unpacklo_epi16(t2_val0, t2_val1);
    __m256i t4_val1 = _mm256_unpackhi_epi16(t2_val0, t2_val1);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr),
                     _mm256_extractf128_si256(t4_val0, 0));
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16),
                     _mm256_extractf128_si256(t4_val1, 0));
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 32),
                     _mm256_extractf128_si256(t4_val0, 1));
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 48),
                     _mm256_extractf128_si256(t4_val1, 1));
    src_ptr += 16;
    packed_ptr += packed_stride * 16;
  }
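  // Handle the remaining source columns (fewer than 16) one at a time, still
  // packing 4 rows per column and, when sums is non-null, accumulating the
  // per-column sums.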
  for (; col < src_end_col; col++) {
    std::int32_t accum = 0;
    for (int r = 0; r < 4; r++) {
      std::int8_t packed_val;
      if (block_row + r < src_rows) {
        packed_val = input_xor ^ src_ptr[r * src_stride];
      } else {
        packed_val = input_xor ^ src_zero_point;
      }
      accum += packed_val;
      *packed_ptr++ = packed_val;
    }
    if (sums) {
      sums[col] += accum;
    }
    src_ptr++;
  }
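  // Zero-fill the packed data for any padding columns past src_cols, up to
  // end_col.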
  for (; col < end_col; col++) {
    std::memset(packed_ptr, 0, 4);
    packed_ptr += 4;
  }
}

#endif  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)

}  // namespace ruy