/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_PACK_X86_H_
#define RUY_RUY_PACK_X86_H_

#include <cstdint>
#include <cstring>
#include <type_traits>

#include "ruy/check_macros.h"
#include "ruy/mat.h"
#include "ruy/opt_set.h"
#include "ruy/pack_common.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/tune.h"

namespace ruy {

#if RUY_PLATFORM_X86

RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx)
RUY_INHERIT_PACK(Path::kAvx, Path::kAvx2Fma)
RUY_INHERIT_PACK(Path::kAvx2Fma, Path::kAvx512)

RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx2Fma, 8)
RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx512, 16)

template <>
struct PackedTypeImpl<Path::kAvx, std::uint8_t> {
  using Type = std::int8_t;
};

template <>
struct PackedTypeImpl<Path::kAvx2Fma, std::uint8_t> {
  using Type = std::int8_t;
};

template <>
struct PackedTypeImpl<Path::kAvx512, std::uint8_t> {
  using Type = std::int8_t;
};

// Note that source and zero buffers can be uint8 type, but in the packing
// function are reinterpreted as int8, and are XOR-ed with input_xor.
void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
                             const std::int8_t* zerobuf, int src_stride,
                             int remaining_src_cols, int src_rows,
                             std::int8_t* packed_ptr, std::int32_t* sums_ptr);

template <typename Scalar>
struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
                Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  static constexpr std::int8_t kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;

  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX2 8-bit)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    std::int32_t* sums = packed_matrix->sums;
    Scalar zerobuf[Layout::kCols * Layout::kRows];
    memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
           Layout::kCols * Layout::kRows * sizeof(Scalar));
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
      int src_stride = src_matrix.layout.stride;
      const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
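      // block_col is a multiple of Layout::kCols (a power of two), so applying
      // block_col_mask keeps it at the aligned start of the current block.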
      std::int8_t* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      Pack8bitColMajorForAvx2(
          reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
          reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
          remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
    }
  }
};

void Pack8bitColMajorForAvx(const std::int8_t* src_ptr, std::int8_t input_xor,
                            const std::int8_t* zerobuf, int src_stride,
                            int remaining_src_cols, int src_rows,
                            std::int8_t* packed_ptr, std::int32_t* sums_ptr);

template <typename Scalar>
struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
                std::int8_t, std::int32_t, Order::kColMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  static constexpr std::int8_t kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;

  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX 8-bit)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    std::int32_t* sums = packed_matrix->sums;
    Scalar zerobuf[Layout::kCols * Layout::kRows];
    memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
           Layout::kCols * Layout::kRows * sizeof(Scalar));
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
      int src_stride = src_matrix.layout.stride;
      const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      std::int8_t* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      Pack8bitColMajorForAvx(
          reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
          reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
          remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
    }
  }
};

void PackFloatColMajorForAvx(const float* src_ptr, const float* zerobuf,
                             int src_stride, int remaining_src_cols,
                             int src_rows, float* packed_ptr);

template <>
struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
                float, float, Order::kColMajor> {
  using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  static void Run(Tuning, const Mat<float>& src_matrix,
                  PMat<float>* packed_matrix, int start_col, int end_col) {
    profiler::ScopeLabel label("Pack (AVX float)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    const float zerobuf[Layout::kCols] = {
        0.0f};  // Remainder default inits to 0.0f.
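    // zerobuf supplies zero-filled source data for columns past the edge of
    // the matrix; see the column handling in the packer further below.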
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      int src_stride = src_matrix.layout.stride;
      const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      float* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      PackFloatColMajorForAvx(src_ptr, zerobuf, src_stride, remaining_src_cols,
                              src_matrix.layout.rows, packed_ptr);
    }
  }
};

void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf,
                              int src_stride, int remaining_src_cols,
                              int src_rows, float* packed_ptr);

template <>
struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>,
                float, float, float, Order::kColMajor> {
  using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  static void Run(Tuning, const Mat<float>& src_matrix,
                  PMat<float>* packed_matrix, int start_col, int end_col) {
    profiler::ScopeLabel label("Pack (AVX2 float)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    const float zerobuf[Layout::kCols] = {
        0.0f};  // Remainder default inits to 0.0f.
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      int src_stride = src_matrix.layout.stride;
      const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      float* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      PackFloatColMajorForAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols,
                               src_matrix.layout.rows, packed_ptr);
    }
  }
};

// Note that source and zero buffers can be uint8 type, but in the packing
// function are reinterpreted as int8, and are XOR-ed with input_xor.
void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
                               std::int8_t input_xor,
                               const std::int8_t* zerobuf, int src_stride,
                               int remaining_src_cols, int src_rows,
                               std::int8_t* packed_ptr, std::int32_t* sums_ptr);

template <typename Scalar>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
                Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  static constexpr int kHalfLayoutCols =
      8;  // Half the number of cols in a block.
  static constexpr std::int8_t kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;

  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX-512 8-bit)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
    std::int32_t* sums = packed_matrix->sums;
    Scalar zerobuf[kHalfLayoutCols * Layout::kRows];
    memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
           kHalfLayoutCols * Layout::kRows * sizeof(Scalar));
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
      int src_stride = src_matrix.layout.stride;
      const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      std::int8_t* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      Pack8bitColMajorForAvx512(
          reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
          reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
          remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
    }
  }
};

void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
                                int src_stride, int remaining_src_cols,
                                int src_rows, float* packed_ptr);

template <>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kRowMajor, 1, 16>,
                float, float, float, Order::kColMajor> {
  static void Run(Tuning, const Mat<float>& src_matrix,
                  PMat<float>* packed_matrix, int start_col, int end_col) {
    profiler::ScopeLabel label("Pack (AVX-512 float)");
    using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    const float zerobuf[Layout::kCols] = {
        0.0f};  // Remainder default inits to 0.0f.
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      int src_stride = src_matrix.layout.stride;
      const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      float* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      PackFloatColMajorForAvx512(src_ptr, zerobuf, src_stride,
                                 remaining_src_cols, src_matrix.layout.rows,
                                 packed_ptr);
    }
  }
};

void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
                             int src_zero_point, std::int8_t* packed_ptr,
                             int packed_stride, int start_col, int end_col,
                             int src_cols, int block_row, int src_rows,
                             int input_xor, std::int32_t* sums);

template <typename Scalar>
struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
                Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (kAvx2Fma 8bit row-major)");
    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
    static constexpr int kInputXor =
        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
    std::int32_t* sums = packed_matrix->sums;
    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
    int block_row = 0;
    for (; block_row < packed_matrix->layout.rows; block_row += 4) {
      int src_stride = src_matrix.layout.stride;
      int packed_stride = packed_matrix->layout.stride;
      const Scalar* src_ptr =
          src_matrix.data.get() + block_row * src_stride + start_col;
      std::int8_t* packed_ptr =
          packed_matrix->data + start_col * packed_stride + block_row * 8;
      Pack8bitRowMajorForAvx2(reinterpret_cast<const std::uint8_t*>(src_ptr),
                              src_stride, src_matrix.zero_point, packed_ptr,
                              packed_stride, start_col, end_col,
                              src_matrix.layout.cols, block_row,
                              src_matrix.layout.rows, kInputXor, sums);
    }
  }
};

void Pack8bitRowMajorForAvx(const std::uint8_t* src_ptr, int src_stride,
                            int src_zero_point, std::int8_t* packed_ptr,
                            int packed_stride, int start_col, int end_col,
                            int src_cols, int block_row, int src_rows,
                            int input_xor, std::int32_t* sums);

template <typename Scalar>
struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
                std::int8_t, std::int32_t, Order::kRowMajor> {
  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX 8bit row-major)");
    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
    static constexpr int kInputXor =
        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
    std::int32_t* sums = packed_matrix->sums;
    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
    int block_row = 0;
    for (; block_row < packed_matrix->layout.rows; block_row += 4) {
      int src_stride = src_matrix.layout.stride;
      int packed_stride = packed_matrix->layout.stride;
      const Scalar* src_ptr =
          src_matrix.data.get() + block_row * src_stride + start_col;
      std::int8_t* packed_ptr =
          packed_matrix->data + start_col * packed_stride + block_row * 8;
      Pack8bitRowMajorForAvx(reinterpret_cast<const std::uint8_t*>(src_ptr),
                             src_stride, src_matrix.zero_point, packed_ptr,
                             packed_stride, start_col, end_col,
                             src_matrix.layout.cols, block_row,
                             src_matrix.layout.rows, kInputXor, sums);
    }
  }
};

void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
                               int src_zero_point, std::int8_t* packed_ptr,
                               int packed_stride, int start_col, int end_col,
                               int src_cols, int block_row, int src_rows,
                               int input_xor, std::int32_t* sums);

template <typename Scalar>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
                Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (kAvx512 8bit row-major)");
    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
    static constexpr int kInputXor =
        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
    std::int32_t* sums = packed_matrix->sums;
    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
    int block_row = 0;
    for (; block_row < packed_matrix->layout.rows; block_row += 4) {
      int src_stride = src_matrix.layout.stride;
      int packed_stride = packed_matrix->layout.stride;
      const Scalar* src_ptr =
          src_matrix.data.get() + block_row * src_stride + start_col;
      std::int8_t* packed_ptr =
          packed_matrix->data + start_col * packed_stride + block_row * 16;
      Pack8bitRowMajorForAvx512(reinterpret_cast<const std::uint8_t*>(src_ptr),
                                src_stride, src_matrix.zero_point, packed_ptr,
                                packed_stride, start_col, end_col,
                                src_matrix.layout.cols, block_row,
                                src_matrix.layout.rows, kInputXor, sums);
    }
  }
};
#endif  // RUY_PLATFORM_X86

}  // namespace ruy

#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))

#include <immintrin.h>  // IWYU pragma: keep

namespace ruy {
namespace {

template <Path path>
inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) {
  return _mm256_castpd_ps(
      _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
}

template <Path path>
inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) {
  return _mm256_castpd_ps(
      _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
}

template <Path path>
inline __m256i CompareGreaterThan(const __m256i&, const __m256i&) {
  RUY_DCHECK(false);
  return _mm256_set1_epi32(0);
}

// Shared between AVX and AVX2+FMA.
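// Loads `available_src_rows` (< 32) bytes starting at `addr` into a 256-bit
// register, padding the remaining bytes with `zero_point`.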
template <Path path>
inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point,
                         const std::int8_t* addr) {
  RUY_DCHECK_LT(available_src_rows, 32);
  __m256i padded_data;

  if (available_src_rows >= 16) {
    __m128i load_hi = _mm_set1_epi8(zero_point);
    __m128i load_lo = _mm_loadu_si128(reinterpret_cast<const __m128i*>(addr));
    memcpy(&load_hi, addr + 16, available_src_rows - 16);
    padded_data = _mm256_set_m128i(load_hi, load_lo);
  } else {
    __m128i load_hi = _mm_set1_epi8(zero_point);
    __m128i load_lo = load_hi;
    memcpy(&load_lo, addr, available_src_rows);
    padded_data = _mm256_set_m128i(load_hi, load_lo);
  }
  return padded_data;
}

}  // namespace.

template <typename PackImpl, Path path>
inline void PackFloatColMajorForAvxCommonPacker(const float* src_ptr,
                                                const float* zerobuf,
                                                int src_stride,
                                                int remaining_src_cols,
                                                int src_rows, float* packed_ptr,
                                                float* trailing_buf) {
  RUY_DCHECK_EQ(PackImpl::Layout::kCols, 8);
  RUY_DCHECK_EQ(PackImpl::Layout::kRows, 1);

  // This packing amounts to transposition of 8x8 blocks.
  static constexpr int kPackCols = 8;  // Source cols packed together.
  static constexpr int kPackRows = 8;  // Short input is padded.

  const float* src_ptr0 = src_ptr;
  const float* src_ptr1 = src_ptr0 + src_stride;
  const float* src_ptr2 = src_ptr1 + src_stride;
  const float* src_ptr3 = src_ptr2 + src_stride;
  const float* src_ptr4 = src_ptr3 + src_stride;
  const float* src_ptr5 = src_ptr4 + src_stride;
  const float* src_ptr6 = src_ptr5 + src_stride;
  const float* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = 8;
  std::int64_t src_inc1 = 8;
  std::int64_t src_inc2 = 8;
  std::int64_t src_inc3 = 8;
  std::int64_t src_inc4 = 8;
  std::int64_t src_inc5 = 8;
  std::int64_t src_inc6 = 8;
  std::int64_t src_inc7 = 8;
  // Handle cases where the source does not have kPackCols (8) columns.
  if (remaining_src_cols < kPackCols) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  for (int k = 0; k < src_rows; k += kPackRows) {
    const int available_src_rows = src_rows - k;
    // Effectively,
    // available_src_rows = std::max(0, std::min(kPackRows, src_rows - k));
    // but treat each case separately.
    if (available_src_rows >= kPackRows) {
      __m256 t0, t1, t2, t3, t4, t5, t6, t7;
      __m256 r0, r1, r2, r3, r4, r5, r6, r7;

      t0 = _mm256_loadu_ps(src_ptr0);
      t4 = _mm256_loadu_ps(src_ptr4);
      t1 = _mm256_loadu_ps(src_ptr1);
      t5 = _mm256_loadu_ps(src_ptr5);
      t2 = _mm256_loadu_ps(src_ptr2);
      t6 = _mm256_loadu_ps(src_ptr6);
      t3 = _mm256_loadu_ps(src_ptr3);
      t7 = _mm256_loadu_ps(src_ptr7);

      r0 = _mm256_unpacklo_ps(t0, t1);
      r4 = _mm256_unpacklo_ps(t4, t5);
      r2 = _mm256_unpackhi_ps(t0, t1);
      r6 = _mm256_unpackhi_ps(t4, t5);
      r1 = _mm256_unpacklo_ps(t2, t3);
      r5 = _mm256_unpacklo_ps(t6, t7);
      r3 = _mm256_unpackhi_ps(t2, t3);
      r7 = _mm256_unpackhi_ps(t6, t7);

      t0 = Mm256UnpackloPsx2<path>(r0, r1);
      t4 = Mm256UnpackloPsx2<path>(r4, r5);
      t2 = Mm256UnpackhiPsx2<path>(r0, r1);
      t6 = Mm256UnpackhiPsx2<path>(r4, r5);
      t1 = Mm256UnpackloPsx2<path>(r2, r3);
      t5 = Mm256UnpackloPsx2<path>(r6, r7);
      t3 = Mm256UnpackhiPsx2<path>(r2, r3);
      t7 = Mm256UnpackhiPsx2<path>(r6, r7);

      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleaves by
      // 16 bytes (128-bit), operating *between* AVX lanes. For instance,
      // (t0, t4) are interleaved to create (r0, r1). This complexity follows
      // from the way that AVX is centered around MM 128-bit lanes.
      r0 = _mm256_permute2f128_ps(t0, t4, 0x20);
      r4 = _mm256_permute2f128_ps(t1, t5, 0x20);
      r1 = _mm256_permute2f128_ps(t0, t4, 0x31);
      r5 = _mm256_permute2f128_ps(t1, t5, 0x31);
      r2 = _mm256_permute2f128_ps(t2, t6, 0x20);
      r6 = _mm256_permute2f128_ps(t3, t7, 0x20);
      r3 = _mm256_permute2f128_ps(t2, t6, 0x31);
      r7 = _mm256_permute2f128_ps(t3, t7, 0x31);

      _mm256_storeu_ps(packed_ptr + 0 * 8, r0);
      _mm256_storeu_ps(packed_ptr + 2 * 8, r4);
      _mm256_storeu_ps(packed_ptr + 4 * 8, r1);
      _mm256_storeu_ps(packed_ptr + 6 * 8, r5);
      _mm256_storeu_ps(packed_ptr + 1 * 8, r2);
      _mm256_storeu_ps(packed_ptr + 3 * 8, r6);
      _mm256_storeu_ps(packed_ptr + 5 * 8, r3);
      _mm256_storeu_ps(packed_ptr + 7 * 8, r7);
    } else if (available_src_rows > 0) {
      const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
      const __m256i row_mask_v = CompareGreaterThan<path>(
          _mm256_set1_epi32(available_src_rows), series);

      __m256 t0, t1, t2, t3, t4, t5, t6, t7;
      __m256 r0, r1, r2, r3, r4, r5, r6, r7;

      t0 = _mm256_maskload_ps(src_ptr0, row_mask_v);
      t4 = _mm256_maskload_ps(src_ptr4, row_mask_v);
      t1 = _mm256_maskload_ps(src_ptr1, row_mask_v);
      t5 = _mm256_maskload_ps(src_ptr5, row_mask_v);
      t2 = _mm256_maskload_ps(src_ptr2, row_mask_v);
      t6 = _mm256_maskload_ps(src_ptr6, row_mask_v);
      t3 = _mm256_maskload_ps(src_ptr3, row_mask_v);
      t7 = _mm256_maskload_ps(src_ptr7, row_mask_v);

      r0 = _mm256_unpacklo_ps(t0, t1);
      r4 = _mm256_unpacklo_ps(t4, t5);
      r2 = _mm256_unpackhi_ps(t0, t1);
      r6 = _mm256_unpackhi_ps(t4, t5);
      r1 = _mm256_unpacklo_ps(t2, t3);
      r5 = _mm256_unpacklo_ps(t6, t7);
      r3 = _mm256_unpackhi_ps(t2, t3);
      r7 = _mm256_unpackhi_ps(t6, t7);

      t0 = Mm256UnpackloPsx2<path>(r0, r1);
      t4 = Mm256UnpackloPsx2<path>(r4, r5);
      t2 = Mm256UnpackhiPsx2<path>(r0, r1);
      t6 = Mm256UnpackhiPsx2<path>(r4, r5);
      t1 = Mm256UnpackloPsx2<path>(r2, r3);
      t5 = Mm256UnpackloPsx2<path>(r6, r7);
      t3 = Mm256UnpackhiPsx2<path>(r2, r3);
      t7 = Mm256UnpackhiPsx2<path>(r6, r7);

      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleaves by
      // 16 bytes (128-bit), operating *between* AVX lanes. For instance,
      // (t0, t4) are interleaved to create (r0, r1). This complexity follows
      // from the way that AVX is centered around MM 128-bit lanes.
      r0 = _mm256_permute2f128_ps(t0, t4, 0x20);
      r4 = _mm256_permute2f128_ps(t1, t5, 0x20);
      r1 = _mm256_permute2f128_ps(t0, t4, 0x31);
      r5 = _mm256_permute2f128_ps(t1, t5, 0x31);
      r2 = _mm256_permute2f128_ps(t2, t6, 0x20);
      r6 = _mm256_permute2f128_ps(t3, t7, 0x20);
      r3 = _mm256_permute2f128_ps(t2, t6, 0x31);
      // r7 no longer needed.

      _mm256_storeu_ps(trailing_buf + 0 * 8, r0);
      _mm256_storeu_ps(trailing_buf + 2 * 8, r4);
      _mm256_storeu_ps(trailing_buf + 4 * 8, r1);
      _mm256_storeu_ps(trailing_buf + 6 * 8, r5);
      _mm256_storeu_ps(trailing_buf + 1 * 8, r2);
      _mm256_storeu_ps(trailing_buf + 3 * 8, r6);
      _mm256_storeu_ps(trailing_buf + 5 * 8, r3);
      // No store to (trailing_buf + 7 * 8), space not allocated.
    }

    packed_ptr += kPackRows * kPackCols;
    src_ptr0 += src_inc0;
    src_ptr1 += src_inc1;
    src_ptr2 += src_inc2;
    src_ptr3 += src_inc3;
    src_ptr4 += src_inc4;
    src_ptr5 += src_inc5;
    src_ptr6 += src_inc6;
    src_ptr7 += src_inc7;
  }
}
}  // namespace ruy
#endif  // (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)

#endif  // RUY_RUY_PACK_X86_H_