/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/opt_set.h"
#include "ruy/pack_x86.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM))

void Pack8bitColMajorForAvx512(const std::int8_t*, std::int8_t,
                               const std::int8_t*, int, int, int, std::int8_t*,
                               std::int32_t*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void PackFloatColMajorForAvx512(const float*, const float*, int, int, int,
                                float*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Pack8bitRowMajorForAvx512(const std::uint8_t*, int, int, std::int8_t*, int,
                               int, int, int, int, int, int, std::int32_t*) {
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)

// The first int8_t template parameter is arbitrary: this routine is common to
// all 8-bit source matrix types.
using PackImpl8bitAvx512 =
    PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
             std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>;
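
// Packed-layout sketch (inferred from the FixedKernelLayout above and the
// store offsets below): each packed block covers Layout::kRows = 4 source
// rows and Layout::kCols = 16 destination columns, stored as 4 consecutive
// values per column, i.e. 64 bytes per block. Each HalfPack* routine below
// fills 8 of those 16 columns, so the second half of a block starts at byte
// offset 8 * 4 = 32 (kHalfBlockOffset in the caller).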

namespace {

inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point,
                               std::int8_t* packed_ptr) {
  using Layout = PackImpl8bitAvx512::Layout;
  static constexpr int kHalfLayoutCols =
      PackImpl8bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
                                            // block.
  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
  RUY_DCHECK_EQ(Layout::kCols, 16);
  RUY_DCHECK_EQ(Layout::kRows, 4);

  const int non_trailing_blocks = (src_rows & ~31) >> 2;
  // This routine fills half blocks, and typically fills the second halves.
  // Thus packed_ptr is already offset by 8 * 4.
  for (int k = 0; k < non_trailing_blocks; ++k) {
    for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) {
      packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point;
    }
  }
}

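// LoaduTwo and MaskLoaduTwo below gather two 32-byte column slices into one
// 512-bit register: the bytes at addr_lo fill the lower 256 bits and the
// bytes at addr_hi the upper 256 bits. The masked variant substitutes
// default_value_v (the packing zero point) for lanes outside row_mask, i.e.
// for source rows that do not exist.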
inline __m512i LoaduTwo(const std::int8_t* addr_lo,
                        const std::int8_t* addr_hi) {
  __m512i lower_filled = _mm512_castsi256_si512(
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_lo)));
  return _mm512_inserti32x8(
      lower_filled,
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_hi)), 1);
}

inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
                            const std::int8_t* addr_lo,
                            const std::int8_t* addr_hi) {
  const __m512i lower_filled = _mm512_castsi256_si512(
      _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo));
  return _mm512_inserti32x8(
      lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi),
      1);
}

inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
                               std::int8_t input_xor,
                               const std::int8_t* zerobuf, int src_stride,
                               int remaining_src_cols, int src_rows,
                               std::int8_t* packed_ptr, std::int32_t* sums_ptr,
                               std::int8_t* trailing_buf) {
  using Layout = PackImpl8bitAvx512::Layout;
  RUY_DCHECK_EQ(Layout::kCols, 16);
  RUY_DCHECK_EQ(Layout::kRows, 4);
  // Each Layout::kRows chunk is 4 elements that are contiguous in the input
  // and stay contiguous in the packed output.
  // We process 8 of these chunks at a time, padding short input chunks.
  constexpr int kNumRowChunks = 8;
  constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;

  const std::int8_t* src_ptr0 = src_ptr;
  const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
  const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
  const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
  const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
  const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
  const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
  const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = kNumChunkedSrcRows;
  std::int64_t src_inc1 = kNumChunkedSrcRows;
  std::int64_t src_inc2 = kNumChunkedSrcRows;
  std::int64_t src_inc3 = kNumChunkedSrcRows;
  std::int64_t src_inc4 = kNumChunkedSrcRows;
  std::int64_t src_inc5 = kNumChunkedSrcRows;
  std::int64_t src_inc6 = kNumChunkedSrcRows;
  std::int64_t src_inc7 = kNumChunkedSrcRows;
  // Handle cases where source does not have kHalfLayoutCols (8) columns.
  if (remaining_src_cols < 8) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }
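
  // For example, with remaining_src_cols == 3, src_ptr3..src_ptr7 all read
  // from zerobuf and never advance, so columns 3..7 of this half-block are
  // filled from the zero-point buffer rather than from real source data.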

  const std::int8_t zero_point = zerobuf[0];

  if (sums_ptr) {
    // i: kHalfLayoutCols.
    for (int i = 0; i < 8; ++i) {
      sums_ptr[i] = 0;
    }
  }
  std::int32_t sums_adjustment = 0;
  const __m512i ones_16bit = _mm512_set1_epi16(1);
  __m512i sums_8x2_32bit = _mm512_set1_epi32(0);

  // The overall packing effectively pads the source rows to
  // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
  // only pack for (src_rows + 31) & ~31. When there is an incomplete
  // destination block, this is stored into trailing_buf instead of packed_ptr.
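  // For example, with src_rows == 40 the loop below runs for k == 0 only:
  // m == 0 packs rows 0..31 directly into packed_ptr, and m == 1 sees
  // available_src_rows == 8, so rows 32..39 are packed (with padding) into
  // trailing_buf; the caller later copies the trailing rows back into
  // packed_ptr.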
  for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
    // m: {0, 1} for 2 chunks of rows.
    for (int m = 0; m < 2; ++m) {
      // Available source rows.
      // If this is less than 0 (for m=1), we skip, having filled trailing
      // buffer for m=0. Also, if source rows is zero on m=1, then we filled
      // exactly to the end of the column in the packed buffer.
      const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
      // Effectively,
      // available_src_rows =
      //     std::max(0, std::min(8 * 4, src_rows - k - 8 * 4 * m));
      // but we treat each case separately.
      if (available_src_rows >= kNumChunkedSrcRows) {
        // i: chunks, s: Layout::kRows.
        if (sums_ptr) {
          __m512i t0, t1, t2, t3;
          __m512i r0, r1, r2, r3;
          const __m512i input_xor_v = _mm512_set1_epi8(input_xor);

          t0 = LoaduTwo(src_ptr0, src_ptr4);
          t1 = LoaduTwo(src_ptr1, src_ptr5);
          t2 = LoaduTwo(src_ptr2, src_ptr6);
          t3 = LoaduTwo(src_ptr3, src_ptr7);

          r0 = _mm512_unpacklo_epi32(t0, t1);
          r2 = _mm512_unpackhi_epi32(t0, t1);
          r1 = _mm512_unpacklo_epi32(t2, t3);
          r3 = _mm512_unpackhi_epi32(t2, t3);

          t0 = _mm512_unpacklo_epi64(r0, r1);
          t2 = _mm512_unpackhi_epi64(r0, r1);
          t1 = _mm512_unpacklo_epi64(r2, r3);
          t3 = _mm512_unpackhi_epi64(r2, r3);

          r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
          r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
          r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
          r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);

          r0 = _mm512_xor_si512(r0, input_xor_v);
          r1 = _mm512_xor_si512(r1, input_xor_v);
          r2 = _mm512_xor_si512(r2, input_xor_v);
          r3 = _mm512_xor_si512(r3, input_xor_v);

          const __m256i r0_0 = _mm512_castsi512_si256(r0);
          const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
          const __m256i r1_0 = _mm512_castsi512_si256(r1);
          const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
          const __m256i r2_0 = _mm512_castsi512_si256(r2);
          const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
          const __m256i r3_0 = _mm512_castsi512_si256(r3);
          const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);

          __m512i sums_8x4_16bit;
          sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
          // The sums have been performed across columns, and now we have
          // 4x16-bit sums packed together. We use madd for pairwise 32-bit
          // sums.
          const __m512i sums_8x2_32bit_new =
              _mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
          sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);

          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1);
        } else {
          __m512i t0, t1, t2, t3;
          __m512i r0, r1, r2, r3;
          const __m512i input_xor_v = _mm512_set1_epi8(input_xor);

          t0 = LoaduTwo(src_ptr0, src_ptr4);
          t1 = LoaduTwo(src_ptr1, src_ptr5);
          t2 = LoaduTwo(src_ptr2, src_ptr6);
          t3 = LoaduTwo(src_ptr3, src_ptr7);

          r0 = _mm512_unpacklo_epi32(t0, t1);
          r2 = _mm512_unpackhi_epi32(t0, t1);
          r1 = _mm512_unpacklo_epi32(t2, t3);
          r3 = _mm512_unpackhi_epi32(t2, t3);

          t0 = _mm512_unpacklo_epi64(r0, r1);
          t2 = _mm512_unpackhi_epi64(r0, r1);
          t1 = _mm512_unpacklo_epi64(r2, r3);
          t3 = _mm512_unpackhi_epi64(r2, r3);

          r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
          r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
          r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
          r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);

          r0 = _mm512_xor_si512(r0, input_xor_v);
          r1 = _mm512_xor_si512(r1, input_xor_v);
          r2 = _mm512_xor_si512(r2, input_xor_v);
          r3 = _mm512_xor_si512(r3, input_xor_v);

          const __m256i r0_0 = _mm512_castsi512_si256(r0);
          const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
          const __m256i r1_0 = _mm512_castsi512_si256(r1);
          const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
          const __m256i r2_0 = _mm512_castsi512_si256(r2);
          const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
          const __m256i r3_0 = _mm512_castsi512_si256(r3);
          const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1);
        }
      } else if (available_src_rows > 0) {
        RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
        const __mmask32 row_mask =
            (static_cast<std::uint64_t>(1) << available_src_rows) - 1;
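        // row_mask has the low available_src_rows bits set: the masked loads
        // below read only the rows that exist and substitute the zero point
        // for the rest.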

        // We do not care what goes into the trailing buffer, but we want
        // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
        //
        // We compensate for padding-with-zero_point by initializing the
        // summations with the compensating offset, effectively
        // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
        //                         4 * (8 - ((available_src_rows + 3) >> 2)).
        //
        // Note that (zero_point ^ input_xor) is performed in 8-bits and then
        // cast.
        sums_adjustment += -(zero_point ^ input_xor) * 4 *
                           (8 - ((available_src_rows + 3) >> 2));
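        // For example, available_src_rows == 5 keeps ceil(5 / 4) == 2 of the
        // 8 chunks as real destination data; the other 6 chunks (24 values
        // per column) exist only as padding, so their zero-point contribution
        // is cancelled out of the column sums here.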

        __m512i t0, t1, t2, t3;
        __m512i r0, r1, r2, r3;
        const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
        const __m256i zero_point_v = _mm256_set1_epi8(zero_point);

        t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4);
        t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5);
        t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6);
        t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7);

        r0 = _mm512_unpacklo_epi32(t0, t1);
        r2 = _mm512_unpackhi_epi32(t0, t1);
        r1 = _mm512_unpacklo_epi32(t2, t3);
        r3 = _mm512_unpackhi_epi32(t2, t3);

        t0 = _mm512_unpacklo_epi64(r0, r1);
        t2 = _mm512_unpackhi_epi64(r0, r1);
        t1 = _mm512_unpacklo_epi64(r2, r3);
        t3 = _mm512_unpackhi_epi64(r2, r3);

        r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
        r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
        r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
        r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);

        r0 = _mm512_xor_si512(r0, input_xor_v);
        r1 = _mm512_xor_si512(r1, input_xor_v);
        r2 = _mm512_xor_si512(r2, input_xor_v);
        r3 = _mm512_xor_si512(r3, input_xor_v);

        const __m256i r0_0 = _mm512_castsi512_si256(r0);
        const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
        const __m256i r1_0 = _mm512_castsi512_si256(r1);
        const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
        const __m256i r2_0 = _mm512_castsi512_si256(r2);
        const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
        const __m256i r3_0 = _mm512_castsi512_si256(r3);
        const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);

        __m512i sums_8x4_16bit;
        sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
        // The sums have been performed across columns, and now we have
        // 4x16-bit sums packed together. We use madd for pairwise 32-bit
        // sums.
        const __m512i sums_8x2_32bit_new =
            _mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
        sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);

        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 0 * 16 * 4), r0_0);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 2 * 16 * 4), r0_1);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 4 * 16 * 4), r1_0);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 6 * 16 * 4), r1_1);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 1 * 16 * 4), r2_0);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 3 * 16 * 4), r2_1);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 5 * 16 * 4), r3_0);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 7 * 16 * 4), r3_1);
      }

      packed_ptr += 16 * kNumChunkedSrcRows;
      src_ptr0 += src_inc0;
      src_ptr1 += src_inc1;
      src_ptr2 += src_inc2;
      src_ptr3 += src_inc3;
      src_ptr4 += src_inc4;
      src_ptr5 += src_inc5;
      src_ptr6 += src_inc6;
      src_ptr7 += src_inc7;
    }
  }

  if (sums_ptr) {
    const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);

    __m256i sums =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
    const __m512i idx =
        _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);

    // We earlier used madd for pairwise 32-bit sums, and now we deinterlace
    // the neighbours, finishing up by adding them to the stored accumulated
    // sums.
    const __m512i sums_2x8_32bit =
        _mm512_permutexvar_epi32(idx, sums_8x2_32bit);
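    // After this permute, lane c of the lower 256-bit half holds one of the
    // two partial sums for column c and lane c of the upper half holds the
    // other, so adding the two halves below yields one 32-bit sum per column.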
    sums = _mm256_add_epi32(sums, sums_adjustment_v);
    sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit));
    sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1));

    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
  }
}

inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) {
  const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo));
  return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1);
}

inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo,
                           const float* addr_hi) {
  const __m512 lower_filled =
      _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo));
  return _mm512_insertf32x8(lower_filled,
                            _mm256_maskz_loadu_ps(row_mask, addr_hi), 1);
}

inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) {
  return _mm512_castpd_ps(
      _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
}

inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) {
  return _mm512_castpd_ps(
      _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
}
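
// The Psx2 helpers unpack at 64-bit granularity, i.e. they move pairs of
// floats as single units. They provide the middle (64-bit) stage of the
// 32-bit / 64-bit / 128-bit shuffle sequence used to transpose floats below.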

inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf,
                                int src_stride, int remaining_src_cols,
                                int src_rows, float* packed_ptr,
                                float* trailing_buf) {
  const float* src_ptr0 = src_ptr;
  const float* src_ptr1 = src_ptr0 + src_stride;
  const float* src_ptr2 = src_ptr1 + src_stride;
  const float* src_ptr3 = src_ptr2 + src_stride;
  const float* src_ptr4 = src_ptr3 + src_stride;
  const float* src_ptr5 = src_ptr4 + src_stride;
  const float* src_ptr6 = src_ptr5 + src_stride;
  const float* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = 8;
  std::int64_t src_inc1 = 8;
  std::int64_t src_inc2 = 8;
  std::int64_t src_inc3 = 8;
  std::int64_t src_inc4 = 8;
  std::int64_t src_inc5 = 8;
  std::int64_t src_inc6 = 8;
  std::int64_t src_inc7 = 8;
  if (remaining_src_cols < 8) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  for (int k = 0; k < src_rows; k += 16) {
    for (int m = 0; m < 2; ++m) {
      const int available_src_rows = src_rows - k - 8 * m;
      // Effectively,
      // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m));
      // but treat each case separately.
      if (available_src_rows > 7) {
        __m512 t0, t1, t2, t3;
        __m512 r0, r1, r2, r3;

        t0 = LoaduTwo(src_ptr0, src_ptr4);
        t1 = LoaduTwo(src_ptr1, src_ptr5);
        t2 = LoaduTwo(src_ptr2, src_ptr6);
        t3 = LoaduTwo(src_ptr3, src_ptr7);

        r0 = _mm512_unpacklo_ps(t0, t1);
        r2 = _mm512_unpackhi_ps(t0, t1);
        r1 = _mm512_unpacklo_ps(t2, t3);
        r3 = _mm512_unpackhi_ps(t2, t3);

        t0 = Mm512UnpackloPsx2(r0, r1);
        t2 = Mm512UnpackhiPsx2(r0, r1);
        t1 = Mm512UnpackloPsx2(r2, r3);
        t3 = Mm512UnpackhiPsx2(r2, r3);

        r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
        r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
        r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
        r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);

        _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0));
        _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
        _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1));
        _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
        _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2));
        _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
        _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3));
        _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1));
      } else if (available_src_rows > 0) {
        const __mmask8 row_mask =
            (static_cast<std::uint32_t>(1) << available_src_rows) - 1;
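        // row_mask selects the valid rows; the maskz loads return 0.0f in the
        // masked-off lanes, so missing rows are padded with zeros.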

        __m512 t0, t1, t2, t3;
        __m512 r0, r1, r2, r3;

        t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4);
        t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5);
        t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6);
        t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7);

        r0 = _mm512_unpacklo_ps(t0, t1);
        r2 = _mm512_unpackhi_ps(t0, t1);
        r1 = _mm512_unpacklo_ps(t2, t3);
        r3 = _mm512_unpackhi_ps(t2, t3);

        t0 = Mm512UnpackloPsx2(r0, r1);
        t2 = Mm512UnpackhiPsx2(r0, r1);
        t1 = Mm512UnpackloPsx2(r2, r3);
        t3 = Mm512UnpackhiPsx2(r2, r3);

        r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
        r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
        r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
        r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);

        _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0));
        _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
        _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1));
        _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
        _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2));
        _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
        _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3));
        // Do not store _mm512_extractf32x8_ps(r3, 1).
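        // That upper half would be row 7 of this chunk, which cannot hold
        // valid data on this path (available_src_rows < 8), and trailing_buf
        // only has room for 7 rows of 16 floats.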
      }

      packed_ptr += 16 * 8;
      src_ptr0 += src_inc0;
      src_ptr1 += src_inc1;
      src_ptr2 += src_inc2;
      src_ptr3 += src_inc3;
      src_ptr4 += src_inc4;
      src_ptr5 += src_inc5;
      src_ptr6 += src_inc6;
      src_ptr7 += src_inc7;
    }
  }
}

inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) {
  const int non_trailing_rows = src_rows & ~7;
  for (int k = 0; k < non_trailing_rows; ++k) {
    for (int j = 0; j < 8; ++j) {
      packed_ptr[j] = 0.0f;
    }
    packed_ptr += 16;
  }
}

}  // namespace.

void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
                               std::int8_t input_xor,
                               const std::int8_t* zerobuf, int src_stride,
                               int remaining_src_cols, int src_rows,
                               std::int8_t* packed_ptr,
                               std::int32_t* sums_ptr) {
  profiler::ScopeLabel label("Pack kAvx512 8bit");

  using Layout = PackImpl8bitAvx512::Layout;
  constexpr int kHalfBlockOffset = 32;
  RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
  static constexpr int kHalfLayoutCols =
      PackImpl8bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
                                            // block.
  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
  RUY_DCHECK_EQ(Layout::kCols, 16);
  RUY_DCHECK_EQ(Layout::kRows, 4);

  // Each Layout::kRows chunk is 4 elements that are contiguous in the input
  // and stay contiguous in the packed output.
  // We process 8 of these chunks at a time, padding short input chunks.
  constexpr int kNumRowChunks = 8;

  // Each packed block is 4*16, and there are normally 8. The trailing block is
  // only slightly shorter.
  constexpr int kTrailingBufSize =
      kNumRowChunks * Layout::kCols * Layout::kRows;
  std::int8_t trailing_buf[kTrailingBufSize];
  memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));

  std::int32_t* second_sums_ptr =
      sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
  if (remaining_src_cols > kHalfLayoutCols) {
    HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
                       remaining_src_cols, src_rows, packed_ptr, sums_ptr,
                       trailing_buf);
    HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor,
                       zerobuf, src_stride,
                       remaining_src_cols - kHalfLayoutCols, src_rows,
                       packed_ptr + kHalfBlockOffset, second_sums_ptr,
                       trailing_buf + kHalfBlockOffset);
  } else {
    HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
                       remaining_src_cols, src_rows, packed_ptr, sums_ptr,
                       trailing_buf);
    ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor,
                       packed_ptr + kHalfBlockOffset);
    // The kernel may not need the second half-block's sums to be set.
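    // The second half-block contains no real source data here; its sums are
    // simply defined as the packed zero-point value times the padded row
    // count.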
    if (second_sums_ptr) {
      for (int i = 0; i < kHalfLayoutCols; ++i) {
        second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3);
      }
    }
  }
  constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
  const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
  // If the number of source rows is not a multiple of the chunked row count
  // (kChunkedRowMask + 1), there is data in the trailing buffer that still
  // needs to be copied into the packed output.
  if (trailing_data) {
    const int non_trailing_rows = src_rows & ~kChunkedRowMask;
    // Destination "rows" are padded to next highest multiple of Layout::kRows.
    const int dst_rows = (src_rows + 3) & ~3;
    const int trailing_rows = dst_rows - non_trailing_rows;
    memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
           Layout::kCols * trailing_rows * sizeof(std::int8_t));
  }
}

void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
                                int src_stride, int remaining_src_cols,
                                int src_rows, float* packed_ptr) {
  profiler::ScopeLabel label("Pack kAvx512 float");
  float trailing_buf[7 * 16];
  if (remaining_src_cols > 8) {
    HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
                        src_rows, packed_ptr, trailing_buf);
    HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride,
                        remaining_src_cols - 8, src_rows, packed_ptr + 8,
                        trailing_buf + 8);
  } else {
    memset(trailing_buf, 0, sizeof(trailing_buf));
    HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
                        src_rows, packed_ptr, trailing_buf);
    ZeroHalfFloatAvx512(src_rows, packed_ptr + 8);
  }
  const int trailing_rows = src_rows & 7;
  if (trailing_rows > 0) {
    const int non_trailing_rows = src_rows & ~7;
    memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
           16 * trailing_rows * sizeof(float));
  }
}

void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
                               int src_zero_point, std::int8_t* packed_ptr,
                               int packed_stride, int start_col, int end_col,
                               int src_cols, int block_row, int src_rows,
                               int input_xor, std::int32_t* sums) {
  int col = start_col;
  int src_end_col = std::min(end_col, src_cols);

  for (; col <= src_end_col - 16; col += 16) {
    std::int8_t* dst_ptr = packed_ptr;
    __m128i val0, val1, val2, val3;
    __m128i input_xor_dup = _mm_set1_epi8(input_xor);
    // Load a 4x16 block.
    if (block_row + 4 <= src_rows) {
      val0 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
      val1 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
      val2 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
      val3 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
    } else {
      val0 = _mm_set1_epi8(src_zero_point);
      val1 = val0;
      val2 = val0;
      val3 = val0;
      if (block_row + 0 < src_rows)
        val0 = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
      if (block_row + 1 < src_rows)
        val1 = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
      if (block_row + 2 < src_rows)
        val2 = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
      if (block_row + 3 < src_rows)
        val3 = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
    }
    // Maybe xor the sign bit to convert from uint8 to int8.
    val0 = _mm_xor_si128(val0, input_xor_dup);
    val1 = _mm_xor_si128(val1, input_xor_dup);
    val2 = _mm_xor_si128(val2, input_xor_dup);
    val3 = _mm_xor_si128(val3, input_xor_dup);
    // Update the sums.
    __m256i val16_0 = _mm256_cvtepi8_epi16(val0);
    __m256i val16_1 = _mm256_cvtepi8_epi16(val1);
    __m256i val16_2 = _mm256_cvtepi8_epi16(val2);
    __m256i val16_3 = _mm256_cvtepi8_epi16(val3);
    __m256i new_sum16 = _mm256_add_epi16(_mm256_add_epi16(val16_0, val16_1),
                                         _mm256_add_epi16(val16_2, val16_3));
    __m512i sum =
        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(sums + col));
    sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(new_sum16));
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(sums + col), sum);
    auto zip = [](__m128i x, __m128i y) {
      auto perm_64_0_64_0 = [](__m128i x) {
        return _mm256_permutexvar_epi64(_mm256_setr_epi64x(0, 2, 1, 3),
                                        _mm256_castsi128_si256(x));
      };
      return _mm256_unpacklo_epi8(perm_64_0_64_0(x), perm_64_0_64_0(y));
    };
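    // zip interleaves the bytes of two rows so that, after the 16-bit
    // unpacks below, each group of 4 consecutive bytes holds the 4 row values
    // of one column, which is the packed col-major 4x16 layout.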
    __m256i t2_val0 = zip(val0, val1);
    __m256i t2_val1 = zip(val2, val3);
    __m256i t4_val0 = _mm256_unpacklo_epi16(t2_val0, t2_val1);
    __m256i t4_val1 = _mm256_unpackhi_epi16(t2_val0, t2_val1);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr),
                     _mm256_extractf128_si256(t4_val0, 0));
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16),
                     _mm256_extractf128_si256(t4_val1, 0));
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 32),
                     _mm256_extractf128_si256(t4_val0, 1));
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 48),
                     _mm256_extractf128_si256(t4_val1, 1));
    src_ptr += 16;
    packed_ptr += packed_stride * 16;
  }
  for (; col < src_end_col; col++) {
    std::int32_t accum = 0;
    for (int r = 0; r < 4; r++) {
      std::int8_t packed_val;
      if (block_row + r < src_rows) {
        packed_val = input_xor ^ src_ptr[r * src_stride];
      } else {
        packed_val = input_xor ^ src_zero_point;
      }
      accum += packed_val;
      *packed_ptr++ = packed_val;
    }
    if (sums) {
      sums[col] += accum;
    }
    src_ptr++;
  }
  for (; col < end_col; col++) {
    std::memset(packed_ptr, 0, 4);
    packed_ptr += 4;
  }
}

#endif  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)

}  // namespace ruy