/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_PACK_X86_H_
#define RUY_RUY_PACK_X86_H_

#include <cstdint>
#include <cstring>
#include <type_traits>

#include "ruy/check_macros.h"
#include "ruy/mat.h"
#include "ruy/opt_set.h"
#include "ruy/pack_common.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/tune.h"

namespace ruy {

#if RUY_PLATFORM_X86

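// Let each faster x86 path fall back to the next-slower path's pack
// implementations wherever this header does not provide a specialization of
// its own.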
RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx)
RUY_INHERIT_PACK(Path::kAvx, Path::kAvx2Fma)
RUY_INHERIT_PACK(Path::kAvx2Fma, Path::kAvx512)

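// For row-major float sources whose layout already matches the packed
// layout, packing reduces to memcpy (blocks of 8 columns on AVX2, 16 on
// AVX-512).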
RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx2Fma, 8)
RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx512, 16)

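// On the x86 paths, uint8 source data is packed into int8 storage; the
// kInputXor constants below shift uint8 values into the int8 range.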
template <>
struct PackedTypeImpl<Path::kAvx, std::uint8_t> {
  using Type = std::int8_t;
};

template <>
struct PackedTypeImpl<Path::kAvx2Fma, std::uint8_t> {
  using Type = std::int8_t;
};

template <>
struct PackedTypeImpl<Path::kAvx512, std::uint8_t> {
  using Type = std::int8_t;
};

// Note that the source and zero buffers may be of uint8 type, but in the
// packing function they are reinterpreted as int8 and XOR-ed with input_xor.
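// With input_xor = 0x80 the XOR flips the sign bit, so the bit pattern of a
// uint8 value v, reinterpreted as int8, becomes v - 128 (0 -> -128,
// 128 -> 0, 255 -> 127). This maps uint8 data into the int8 range expected
// by the 8-bit kernels without any widening arithmetic.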
void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
                             const std::int8_t* zerobuf, int src_stride,
                             int remaining_src_cols, int src_rows,
                             std::int8_t* packed_ptr, std::int32_t* sums_ptr);

template <typename Scalar>
struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
                Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  static constexpr std::int8_t kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;

  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX2 8-bit)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    std::int32_t* sums = packed_matrix->sums;
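    // When non-null, sums receives per-column sums of the packed
    // (XOR-adjusted) values; ruy uses these later for zero-point corrections.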
    Scalar zerobuf[Layout::kCols * Layout::kRows];
    memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
           Layout::kCols * Layout::kRows * sizeof(Scalar));
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
      int src_stride = src_matrix.layout.stride;
      const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
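      // Layout::kCols is a power of two, so (block_col & block_col_mask)
      // rounds block_col down to a multiple of Layout::kCols.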
      std::int8_t* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      Pack8bitColMajorForAvx2(
          reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
          reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
          remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
    }
  }
};

void Pack8bitColMajorForAvx(const std::int8_t* src_ptr, std::int8_t input_xor,
                            const std::int8_t* zerobuf, int src_stride,
                            int remaining_src_cols, int src_rows,
                            std::int8_t* packed_ptr, std::int32_t* sums_ptr);

template <typename Scalar>
struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
                std::int8_t, std::int32_t, Order::kColMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  static constexpr std::int8_t kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;

  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX 8-bit)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    std::int32_t* sums = packed_matrix->sums;
    Scalar zerobuf[Layout::kCols * Layout::kRows];
    memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
           Layout::kCols * Layout::kRows * sizeof(Scalar));
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
      int src_stride = src_matrix.layout.stride;
      const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      std::int8_t* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      Pack8bitColMajorForAvx(
          reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
          reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
          remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
    }
  }
};

void PackFloatColMajorForAvx(const float* src_ptr, const float* zerobuf,
                             int src_stride, int remaining_src_cols,
                             int src_rows, float* packed_ptr);

template <>
struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
                float, float, Order::kColMajor> {
  using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  static void Run(Tuning, const Mat<float>& src_matrix,
                  PMat<float>* packed_matrix, int start_col, int end_col) {
    profiler::ScopeLabel label("Pack (AVX float)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    const float zerobuf[Layout::kCols] = {
        0.0f};  // Remaining elements are zero-initialized.
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      int src_stride = src_matrix.layout.stride;
      const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      float* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      PackFloatColMajorForAvx(src_ptr, zerobuf, src_stride, remaining_src_cols,
                              src_matrix.layout.rows, packed_ptr);
    }
  }
};

void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf,
                              int src_stride, int remaining_src_cols,
                              int src_rows, float* packed_ptr);

template <>
struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>,
                float, float, float, Order::kColMajor> {
  using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  static void Run(Tuning, const Mat<float>& src_matrix,
                  PMat<float>* packed_matrix, int start_col, int end_col) {
    profiler::ScopeLabel label("Pack (AVX2 float)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    const float zerobuf[Layout::kCols] = {
        0.0f};  // Remaining elements are zero-initialized.
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      int src_stride = src_matrix.layout.stride;
      const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      float* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      PackFloatColMajorForAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols,
                               src_matrix.layout.rows, packed_ptr);
    }
  }
};

// Note that the source and zero buffers may be of uint8 type, but in the
// packing function they are reinterpreted as int8 and XOR-ed with input_xor.
void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
                               std::int8_t input_xor,
                               const std::int8_t* zerobuf, int src_stride,
                               int remaining_src_cols, int src_rows,
                               std::int8_t* packed_ptr, std::int32_t* sums_ptr);

template <typename Scalar>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
                Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  static constexpr int kHalfLayoutCols =
      8;  // Half the number of cols in a block.
  static constexpr std::int8_t kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;

  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX-512 8-bit)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
    std::int32_t* sums = packed_matrix->sums;
    Scalar zerobuf[kHalfLayoutCols * Layout::kRows];
    memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
           kHalfLayoutCols * Layout::kRows * sizeof(Scalar));
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
      int src_stride = src_matrix.layout.stride;
      const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      std::int8_t* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      Pack8bitColMajorForAvx512(
          reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
          reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
          remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
    }
  }
};

void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
                                int src_stride, int remaining_src_cols,
                                int src_rows, float* packed_ptr);

template <>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kRowMajor, 1, 16>,
                float, float, float, Order::kColMajor> {
  static void Run(Tuning, const Mat<float>& src_matrix,
                  PMat<float>* packed_matrix, int start_col, int end_col) {
    profiler::ScopeLabel label("Pack (AVX-512 float)");
    using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    const float zerobuf[Layout::kCols] = {
        0.0f};  // Remaining elements are zero-initialized.
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      int src_stride = src_matrix.layout.stride;
      const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      float* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      PackFloatColMajorForAvx512(src_ptr, zerobuf, src_stride,
                                 remaining_src_cols, src_matrix.layout.rows,
                                 packed_ptr);
    }
  }
};

void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
                             int src_zero_point, std::int8_t* packed_ptr,
                             int packed_stride, int start_col, int end_col,
                             int src_cols, int block_row, int src_rows,
                             int input_xor, std::int32_t* sums);

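// Unlike the column-major packers above, the row-major kernels receive the
// full [start_col, end_col) column range and handle column blocking
// internally; the Run methods below only iterate over groups of 4 source
// rows.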
template <typename Scalar>
struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
                Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX2 8-bit row-major)");
    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
    static constexpr int kInputXor =
        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
    std::int32_t* sums = packed_matrix->sums;
    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
    int block_row = 0;
    for (; block_row < packed_matrix->layout.rows; block_row += 4) {
      int src_stride = src_matrix.layout.stride;
      int packed_stride = packed_matrix->layout.stride;
      const Scalar* src_ptr =
          src_matrix.data.get() + block_row * src_stride + start_col;
      std::int8_t* packed_ptr =
          packed_matrix->data + start_col * packed_stride + block_row * 8;
      Pack8bitRowMajorForAvx2(reinterpret_cast<const std::uint8_t*>(src_ptr),
                              src_stride, src_matrix.zero_point, packed_ptr,
                              packed_stride, start_col, end_col,
                              src_matrix.layout.cols, block_row,
                              src_matrix.layout.rows, kInputXor, sums);
    }
  }
};

void Pack8bitRowMajorForAvx(const std::uint8_t* src_ptr, int src_stride,
                            int src_zero_point, std::int8_t* packed_ptr,
                            int packed_stride, int start_col, int end_col,
                            int src_cols, int block_row, int src_rows,
                            int input_xor, std::int32_t* sums);

template <typename Scalar>
struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
                std::int8_t, std::int32_t, Order::kRowMajor> {
  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX 8-bit row-major)");
    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
    static constexpr int kInputXor =
        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
    std::int32_t* sums = packed_matrix->sums;
    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
    int block_row = 0;
    for (; block_row < packed_matrix->layout.rows; block_row += 4) {
      int src_stride = src_matrix.layout.stride;
      int packed_stride = packed_matrix->layout.stride;
      const Scalar* src_ptr =
          src_matrix.data.get() + block_row * src_stride + start_col;
      std::int8_t* packed_ptr =
          packed_matrix->data + start_col * packed_stride + block_row * 8;
      Pack8bitRowMajorForAvx(reinterpret_cast<const std::uint8_t*>(src_ptr),
                             src_stride, src_matrix.zero_point, packed_ptr,
                             packed_stride, start_col, end_col,
                             src_matrix.layout.cols, block_row,
                             src_matrix.layout.rows, kInputXor, sums);
    }
  }
};

void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
                               int src_zero_point, std::int8_t* packed_ptr,
                               int packed_stride, int start_col, int end_col,
                               int src_cols, int block_row, int src_rows,
                               int input_xor, std::int32_t* sums);

template <typename Scalar>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
                Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX-512 8-bit row-major)");
    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
    static constexpr int kInputXor =
        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
    std::int32_t* sums = packed_matrix->sums;
    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
    int block_row = 0;
    for (; block_row < packed_matrix->layout.rows; block_row += 4) {
      int src_stride = src_matrix.layout.stride;
      int packed_stride = packed_matrix->layout.stride;
      const Scalar* src_ptr =
          src_matrix.data.get() + block_row * src_stride + start_col;
      std::int8_t* packed_ptr =
          packed_matrix->data + start_col * packed_stride + block_row * 16;
      Pack8bitRowMajorForAvx512(reinterpret_cast<const std::uint8_t*>(src_ptr),
                                src_stride, src_matrix.zero_point, packed_ptr,
                                packed_stride, start_col, end_col,
                                src_matrix.layout.cols, block_row,
                                src_matrix.layout.rows, kInputXor, sums);
    }
  }
};
#endif  // RUY_PLATFORM_X86

}  // namespace ruy

#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))

#include <immintrin.h>  // IWYU pragma: keep

namespace ruy {
namespace {

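// The next two helpers interleave the 64-bit (float-pair) halves of a and b
// within each 128-bit lane by reinterpreting the vectors as packed doubles.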
template <Path path>
inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) {
  return _mm256_castpd_ps(
      _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
}

template <Path path>
inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) {
  return _mm256_castpd_ps(
      _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
}

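// Fallback stub that must never be reached: each path that uses
// CompareGreaterThan is expected to provide its own specialization.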
template <Path path>
inline __m256i CompareGreaterThan(const __m256i&, const __m256i&) {
  RUY_DCHECK(false);
  return _mm256_set1_epi32(0);
}

// Shared between AVX and AVX2+FMA.
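// MaskLoadu loads available_src_rows bytes starting at addr into a 32-byte
// vector and fills the remaining bytes with zero_point. The memcpy-based
// tail handling avoids reading past the end of the source data.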
template <Path path>
inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point,
                         const std::int8_t* addr) {
  RUY_DCHECK_LT(available_src_rows, 32);
  __m256i padded_data;

  if (available_src_rows >= 16) {
    __m128i load_hi = _mm_set1_epi8(zero_point);
    __m128i load_lo = _mm_loadu_si128(reinterpret_cast<const __m128i*>(addr));
    memcpy(&load_hi, addr + 16, available_src_rows - 16);
    padded_data = _mm256_set_m128i(load_hi, load_lo);
  } else {
    __m128i load_hi = _mm_set1_epi8(zero_point);
    __m128i load_lo = load_hi;
    memcpy(&load_lo, addr, available_src_rows);
    padded_data = _mm256_set_m128i(load_hi, load_lo);
  }
  return padded_data;
}

}  // namespace

template <typename PackImpl, Path path>
inline void PackFloatColMajorForAvxCommonPacker(const float* src_ptr,
                                                const float* zerobuf,
                                                int src_stride,
                                                int remaining_src_cols,
                                                int src_rows, float* packed_ptr,
                                                float* trailing_buf) {
  RUY_DCHECK_EQ(PackImpl::Layout::kCols, 8);
  RUY_DCHECK_EQ(PackImpl::Layout::kRows, 1);

  // This packing amounts to transposition of 8x8 blocks.
  static constexpr int kPackCols = 8;  // Source cols packed together.
  static constexpr int kPackRows = 8;  // Short input is padded.

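  // Conceptually, element (row k, col c) of an 8x8 source block ends up at
  // packed_ptr[k * kPackCols + c]: the column-major source block is
  // transposed into the row-major packed layout.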
  const float* src_ptr0 = src_ptr;
  const float* src_ptr1 = src_ptr0 + src_stride;
  const float* src_ptr2 = src_ptr1 + src_stride;
  const float* src_ptr3 = src_ptr2 + src_stride;
  const float* src_ptr4 = src_ptr3 + src_stride;
  const float* src_ptr5 = src_ptr4 + src_stride;
  const float* src_ptr6 = src_ptr5 + src_stride;
  const float* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = 8;
  std::int64_t src_inc1 = 8;
  std::int64_t src_inc2 = 8;
  std::int64_t src_inc3 = 8;
  std::int64_t src_inc4 = 8;
  std::int64_t src_inc5 = 8;
  std::int64_t src_inc6 = 8;
  std::int64_t src_inc7 = 8;
  // Handle cases where the source does not have kPackCols (8) columns.
  if (remaining_src_cols < kPackCols) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  for (int k = 0; k < src_rows; k += kPackRows) {
    const int available_src_rows = src_rows - k;
    // Effectively,
    // available_src_rows = std::max(0, std::min(kPackRows, src_rows - k));
    // but treat each case separately.
    if (available_src_rows >= kPackRows) {
      __m256 t0, t1, t2, t3, t4, t5, t6, t7;
      __m256 r0, r1, r2, r3, r4, r5, r6, r7;

      t0 = _mm256_loadu_ps(src_ptr0);
      t4 = _mm256_loadu_ps(src_ptr4);
      t1 = _mm256_loadu_ps(src_ptr1);
      t5 = _mm256_loadu_ps(src_ptr5);
      t2 = _mm256_loadu_ps(src_ptr2);
      t6 = _mm256_loadu_ps(src_ptr6);
      t3 = _mm256_loadu_ps(src_ptr3);
      t7 = _mm256_loadu_ps(src_ptr7);

      r0 = _mm256_unpacklo_ps(t0, t1);
      r4 = _mm256_unpacklo_ps(t4, t5);
      r2 = _mm256_unpackhi_ps(t0, t1);
      r6 = _mm256_unpackhi_ps(t4, t5);
      r1 = _mm256_unpacklo_ps(t2, t3);
      r5 = _mm256_unpacklo_ps(t6, t7);
      r3 = _mm256_unpackhi_ps(t2, t3);
      r7 = _mm256_unpackhi_ps(t6, t7);

      t0 = Mm256UnpackloPsx2<path>(r0, r1);
      t4 = Mm256UnpackloPsx2<path>(r4, r5);
      t2 = Mm256UnpackhiPsx2<path>(r0, r1);
      t6 = Mm256UnpackhiPsx2<path>(r4, r5);
      t1 = Mm256UnpackloPsx2<path>(r2, r3);
      t5 = Mm256UnpackloPsx2<path>(r6, r7);
      t3 = Mm256UnpackhiPsx2<path>(r2, r3);
      t7 = Mm256UnpackhiPsx2<path>(r6, r7);

      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleaves by
      // 16 bytes (128 bits), operating *between* AVX lanes. For instance,
      // (t0, t4) are interleaved to create (r0, r1). This complexity follows
      // from the way AVX is centered around 128-bit lanes.
      r0 = _mm256_permute2f128_ps(t0, t4, 0x20);
      r4 = _mm256_permute2f128_ps(t1, t5, 0x20);
      r1 = _mm256_permute2f128_ps(t0, t4, 0x31);
      r5 = _mm256_permute2f128_ps(t1, t5, 0x31);
      r2 = _mm256_permute2f128_ps(t2, t6, 0x20);
      r6 = _mm256_permute2f128_ps(t3, t7, 0x20);
      r3 = _mm256_permute2f128_ps(t2, t6, 0x31);
      r7 = _mm256_permute2f128_ps(t3, t7, 0x31);

      _mm256_storeu_ps(packed_ptr + 0 * 8, r0);
      _mm256_storeu_ps(packed_ptr + 2 * 8, r4);
      _mm256_storeu_ps(packed_ptr + 4 * 8, r1);
      _mm256_storeu_ps(packed_ptr + 6 * 8, r5);
      _mm256_storeu_ps(packed_ptr + 1 * 8, r2);
      _mm256_storeu_ps(packed_ptr + 3 * 8, r6);
      _mm256_storeu_ps(packed_ptr + 5 * 8, r3);
      _mm256_storeu_ps(packed_ptr + 7 * 8, r7);
    } else if (available_src_rows > 0) {
      const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
      const __m256i row_mask_v = CompareGreaterThan<path>(
          _mm256_set1_epi32(available_src_rows), series);

      __m256 t0, t1, t2, t3, t4, t5, t6, t7;
      __m256 r0, r1, r2, r3, r4, r5, r6, r7;

      t0 = _mm256_maskload_ps(src_ptr0, row_mask_v);
      t4 = _mm256_maskload_ps(src_ptr4, row_mask_v);
      t1 = _mm256_maskload_ps(src_ptr1, row_mask_v);
      t5 = _mm256_maskload_ps(src_ptr5, row_mask_v);
      t2 = _mm256_maskload_ps(src_ptr2, row_mask_v);
      t6 = _mm256_maskload_ps(src_ptr6, row_mask_v);
      t3 = _mm256_maskload_ps(src_ptr3, row_mask_v);
      t7 = _mm256_maskload_ps(src_ptr7, row_mask_v);

      r0 = _mm256_unpacklo_ps(t0, t1);
      r4 = _mm256_unpacklo_ps(t4, t5);
      r2 = _mm256_unpackhi_ps(t0, t1);
      r6 = _mm256_unpackhi_ps(t4, t5);
      r1 = _mm256_unpacklo_ps(t2, t3);
      r5 = _mm256_unpacklo_ps(t6, t7);
      r3 = _mm256_unpackhi_ps(t2, t3);
      r7 = _mm256_unpackhi_ps(t6, t7);

      t0 = Mm256UnpackloPsx2<path>(r0, r1);
      t4 = Mm256UnpackloPsx2<path>(r4, r5);
      t2 = Mm256UnpackhiPsx2<path>(r0, r1);
      t6 = Mm256UnpackhiPsx2<path>(r4, r5);
      t1 = Mm256UnpackloPsx2<path>(r2, r3);
      t5 = Mm256UnpackloPsx2<path>(r6, r7);
      t3 = Mm256UnpackhiPsx2<path>(r2, r3);
      t7 = Mm256UnpackhiPsx2<path>(r6, r7);

      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleaves by
      // 16 bytes (128 bits), operating *between* AVX lanes. For instance,
      // (t0, t4) are interleaved to create (r0, r1). This complexity follows
      // from the way AVX is centered around 128-bit lanes.
      r0 = _mm256_permute2f128_ps(t0, t4, 0x20);
      r4 = _mm256_permute2f128_ps(t1, t5, 0x20);
      r1 = _mm256_permute2f128_ps(t0, t4, 0x31);
      r5 = _mm256_permute2f128_ps(t1, t5, 0x31);
      r2 = _mm256_permute2f128_ps(t2, t6, 0x20);
      r6 = _mm256_permute2f128_ps(t3, t7, 0x20);
      r3 = _mm256_permute2f128_ps(t2, t6, 0x31);
      // r7 no longer needed.

      _mm256_storeu_ps(trailing_buf + 0 * 8, r0);
      _mm256_storeu_ps(trailing_buf + 2 * 8, r4);
      _mm256_storeu_ps(trailing_buf + 4 * 8, r1);
      _mm256_storeu_ps(trailing_buf + 6 * 8, r5);
      _mm256_storeu_ps(trailing_buf + 1 * 8, r2);
      _mm256_storeu_ps(trailing_buf + 3 * 8, r6);
      _mm256_storeu_ps(trailing_buf + 5 * 8, r3);
      // No store to (trailing_buf + 7 * 8), space not allocated.
    }

    packed_ptr += kPackRows * kPackCols;
    src_ptr0 += src_inc0;
    src_ptr1 += src_inc1;
    src_ptr2 += src_inc2;
    src_ptr3 += src_inc3;
    src_ptr4 += src_inc4;
    src_ptr5 += src_inc5;
    src_ptr6 += src_inc6;
    src_ptr7 += src_inc7;
  }
}

}  // namespace ruy

#endif  // (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)

#endif  // RUY_RUY_PACK_X86_H_