/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_KERNEL_X86_H_
#define RUY_RUY_KERNEL_X86_H_

#include <cstdint>
#include <cstring>

#include "ruy/kernel_common.h"
#include "ruy/mat.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/tune.h"

namespace ruy {

#if RUY_PLATFORM_X86

RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx2Fma)
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx)
RUY_INHERIT_KERNEL(Path::kAvx2Fma, Path::kAvx512)

void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params);

template <typename DstScalar>
struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, std::int32_t,
              DstScalar> {
  static constexpr Path kPath = Path::kAvx512;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvx512SingleCol(params);
    } else {
      Kernel8bitAvx512(params);
    }
  }
};

void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params);

template <>
struct Kernel<Path::kAvx512, float, float, float, float> {
  static constexpr Path kPath = Path::kAvx512;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      KernelFloatAvx512SingleCol(params);
    } else {
      KernelFloatAvx512(params);
    }
  }
};

void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params);
void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params);

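// The AVX2/FMA and AVX wrappers below follow the same pattern as the AVX-512
// ones above: Run() packs its arguments into a KernelParams struct, then
// dispatches either to a single-column kernel (when the destination has one
// column and the channel dimension is rows) or to the general block kernel.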
template <typename DstScalar>
struct Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t,
              DstScalar> {
  static constexpr Path kPath = Path::kAvx2Fma;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvx2SingleCol(params);
    } else {
      Kernel8bitAvx2(params);
    }
  }
};

void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);

template <>
struct Kernel<Path::kAvx2Fma, float, float, float, float> {
  static constexpr Path kPath = Path::kAvx2Fma;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      KernelFloatAvx2SingleCol(params);
    } else {
      KernelFloatAvx2(params);
    }
  }
};

void KernelFloatAvx(const KernelParamsFloat<8, 8>& params);
void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params);

template <>
struct Kernel<Path::kAvx, float, float, float, float> {
  static constexpr Path kPath = Path::kAvx;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      KernelFloatAvxSingleCol(params);
    } else {
      KernelFloatAvx(params);
    }
  }
};

void Kernel8bitAvx(const KernelParams8bit<8, 8>& params);
void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params);

template <typename DstScalar>
struct Kernel<Path::kAvx, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
  static constexpr Path kPath = Path::kAvx;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvxSingleCol(params);
    } else {
      Kernel8bitAvx(params);
    }
  }
};

#endif  // RUY_PLATFORM_X86
}  // namespace ruy

#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))

#include <immintrin.h>  // IWYU pragma: keep

namespace ruy {
namespace {
namespace intrin_utils {

// Defined as a template so clang won't detect it as an unneeded
// definition.
template <Path path>
inline float mm256_get1_ps(const __m256 a, int i) {
  __m256i ai = _mm256_castps_si256(a);
  int float_val_as_int;
  switch (i) {
    case 0:
      float_val_as_int = _mm256_extract_epi32(ai, 0);
      break;
    case 1:
      float_val_as_int = _mm256_extract_epi32(ai, 1);
      break;
    case 2:
      float_val_as_int = _mm256_extract_epi32(ai, 2);
      break;
    case 3:
      float_val_as_int = _mm256_extract_epi32(ai, 3);
      break;
    case 4:
      float_val_as_int = _mm256_extract_epi32(ai, 4);
      break;
    case 5:
      float_val_as_int = _mm256_extract_epi32(ai, 5);
      break;
    case 6:
      float_val_as_int = _mm256_extract_epi32(ai, 6);
      break;
    case 7:
      float_val_as_int = _mm256_extract_epi32(ai, 7);
      break;
    default:
      RUY_DCHECK_LT(i, 8);
      return .0f;
  }
  float float_val;
  std::memcpy(&float_val, &float_val_as_int, sizeof(float_val));
  return float_val;
}

// Defined as a template so clang won't detect it as an unneeded
// definition.
template <Path path>
inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
  for (int i = 0; i < residual_rows; ++i) {
    dst[i] = intrin_utils::mm256_get1_ps<path>(v, i);
  }
}

template <Path path>
inline __m256 MulAdd(const __m256&, const __m256&, const __m256&) {
  // Specializations added for AVX and AVX2FMA paths in their respective kernel
  // files.
  RUY_DCHECK(false);
  return _mm256_set1_ps(0);
}

template <Path path>
inline __m256i mm256_shuffle_epi8(const __m256i&, const __m256i&) {
  // Specializations added for AVX and AVX2FMA paths in their respective kernel
  // files.
  RUY_DCHECK(false);
  return _mm256_set1_epi32(0);
}

// Polyfill for _mm_storeu_si16(dst, v).
template <Path path>
inline void mm_storeu_si16(void* dst, __m128i v) {
#if (defined __clang__) || (defined _MSC_VER)
  _mm_storeu_si16(dst, v);
#else
  // GCC 9 lacks support for _mm_storeu_si16.
  *static_cast<std::int16_t*>(dst) = _mm_extract_epi16(v, 0);
#endif
}

// Polyfill for _mm_storeu_si32(dst, v).
template <Path path>
inline void mm_storeu_si32(void* dst, __m128i v) {
#if (defined __clang__) || (defined _MSC_VER)
  _mm_storeu_si32(dst, v);
#else
  // GCC 9 lacks support for _mm_storeu_si32.
  *static_cast<std::int32_t*>(dst) = _mm_extract_epi32(v, 0);
#endif
}

// Polyfill for _mm_loadu_si32(src).
template <Path path>
inline __m128i mm_loadu_si32(const void* src) {
#if (defined __clang__) || (defined _MSC_VER)
  return _mm_loadu_si32(src);
#else
  // GCC 9 lacks support for _mm_loadu_si32.
  __m128i res;
  asm("movss %[src], %[res]"
      : [res] "=x"(res)
      : [src] "m"(*static_cast<const int*>(src)));
  return res;
#endif
}

template <Path path>
inline __m128i mm256_extracti128_si256(const __m256i&, const int) {
  RUY_DCHECK(false);
  return _mm_setzero_si128();
}

template <Path path>
inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows,
                                         const __m256i v) {
  // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
  const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
  __m256i shuffled_v;
  if (residual_rows > 1) {
    // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4
    // in each 128-bit lane.
    shuffled_v = intrin_utils::mm256_shuffle_epi8<path>(v, repack_perm);
  }
  switch (residual_rows) {
    case 0:
      break;
    case 1:
      dst[0] = _mm256_extract_epi8(v, 0);
      break;
    case 2:
      mm_storeu_si16<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      break;
    case 3: {
      __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 0);
      mm_storeu_si16<path>(dst, trailing_packed);
      dst[2] = _mm_extract_epi8(trailing_packed, 2);
      break;
    }
    case 4:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      break;
    case 5:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      dst[4] = _mm256_extract_epi8(shuffled_v, 16);
      break;
    case 6:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      mm_storeu_si16<path>(dst + 4,
                           mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    case 7: {
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1);
      mm_storeu_si16<path>(dst + 4, trailing_packed);
      dst[6] = _mm_extract_epi8(trailing_packed, 2);
      break;
    }
    case 8:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      mm_storeu_si32<path>(dst + 4,
                           mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    default:
      RUY_DCHECK_LE(residual_rows, 8);
      break;
  }
}

template <Path path>
inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256i v) {
  // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
  const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
  const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
  mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
  mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
}

template <Path path>
inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows,
                                         const __m256i v) {
  intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
      reinterpret_cast<std::uint8_t*>(dst), residual_rows, v);
}

template <Path path>
inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256i v) {
  // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
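  // 0x0c080400 encodes the byte indices {0, 4, 8, 12} used by the shuffle
  // within each 128-bit lane.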
  const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
  const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
  mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
  mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
}

template <Path path>
inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows,
                                          const __m256i v) {
  // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
  // truncating each 32-bit value to 16 bits.
  const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
  __m256i shuffled_v;
  __m128i shuffled_v_low;
  if (residual_rows > 1) {
    shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
    shuffled_v_low = mm256_extracti128_si256<path>(shuffled_v, 0);
  } else {
    shuffled_v_low = mm256_extracti128_si256<path>(v, 0);
  }
  switch (residual_rows) {
    case 0:
      break;
    case 1:
      mm_storeu_si16<path>(dst, shuffled_v_low);
      break;
    case 2:
      mm_storeu_si32<path>(dst, shuffled_v_low);
      break;
    case 3: {
      mm_storeu_si32<path>(dst, shuffled_v_low);
      dst[2] = _mm_extract_epi16(shuffled_v_low, 2);
      break;
    }
    case 4:
      _mm_storeu_si64(dst, shuffled_v_low);
      break;
    case 5:
      _mm_storeu_si64(dst, shuffled_v_low);
      dst[4] = _mm256_extract_epi16(shuffled_v, 8);
      break;
    case 6:
      _mm_storeu_si64(dst, shuffled_v_low);
      mm_storeu_si32<path>(dst + 4,
                           mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    case 7: {
      _mm_storeu_si64(dst, shuffled_v_low);
      __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1);
      mm_storeu_si32<path>(dst + 4, trailing_packed);
      dst[6] = _mm_extract_epi16(trailing_packed, 2);
      break;
    }
    case 8:
      _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    default:
      RUY_DCHECK_LE(residual_rows, 8);
      break;
  }
}

template <Path path>
inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256i v) {
  // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
  // truncating each 32-bit value to 16 bits.
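  // The shuffle control 0x0d0c090805040100 lists the byte indices
  // {0, 1, 4, 5, 8, 9, 12, 13}, i.e. the low 16 bits of each 32-bit element.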
  const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
  const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
  _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
  _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
}

template <Path path>
inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows,
                                 const __m256i v) {
  const __m128i v_low = mm256_extracti128_si256<path>(v, 0);
  switch (residual_rows) {
    case 0:
      break;
    case 1:
      mm_storeu_si32<path>(dst, v_low);
      break;
    case 2:
      _mm_storeu_si64(dst, v_low);
      break;
    case 3: {
      __m128i trailing_packed = v_low;
      _mm_storeu_si64(dst, trailing_packed);
      dst[2] = _mm_extract_epi32(trailing_packed, 2);
      break;
    }
    case 4:
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      break;
    case 5:
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      dst[4] = _mm256_extract_epi32(v, 4);
      break;
    case 6:
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(v, 1));
      break;
    case 7: {
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      __m128i trailing_packed = mm256_extracti128_si256<path>(v, 1);
      _mm_storeu_si64(dst + 4, trailing_packed);
      dst[6] = _mm_extract_epi32(trailing_packed, 2);
      break;
    }
    case 8:
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
      break;
    default:
      RUY_DCHECK_LE(residual_rows, 8);
      break;
  }
}

template <Path path>
inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) {
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
}

// Transpose an 8x8 matrix of floats.
template <Path path>
void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3,
                           __m256* v4, __m256* v5, __m256* v6, __m256* v7) {
  __m256 t2x2_0 = _mm256_unpacklo_ps(*v0, *v1);
  __m256 t2x2_1 = _mm256_unpackhi_ps(*v0, *v1);
  __m256 t2x2_2 = _mm256_unpacklo_ps(*v2, *v3);
  __m256 t2x2_3 = _mm256_unpackhi_ps(*v2, *v3);
  __m256 t2x2_4 = _mm256_unpacklo_ps(*v4, *v5);
  __m256 t2x2_5 = _mm256_unpackhi_ps(*v4, *v5);
  __m256 t2x2_6 = _mm256_unpacklo_ps(*v6, *v7);
  __m256 t2x2_7 = _mm256_unpackhi_ps(*v6, *v7);
  __m256 t4x4_0 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_1 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 t4x4_2 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_3 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 t4x4_4 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_5 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 t4x4_6 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_7 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(3, 2, 3, 2));
  *v0 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x20);
  *v1 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x20);
  *v2 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x20);
  *v3 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x20);
  *v4 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x31);
  *v5 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x31);
  *v6 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x31);
  *v7 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x31);
}

// Transpose an 8x8 matrix of int32s.
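// Implemented by bit-casting to __m256 and reusing the float transpose above;
// a transpose over 32-bit lanes does not depend on the element type.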
template <Path path>
void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2,
                              __m256i* v3, __m256i* v4, __m256i* v5,
                              __m256i* v6, __m256i* v7) {
  mm256_transpose8x8_ps<path>(
      reinterpret_cast<__m256*>(v0), reinterpret_cast<__m256*>(v1),
      reinterpret_cast<__m256*>(v2), reinterpret_cast<__m256*>(v3),
      reinterpret_cast<__m256*>(v4), reinterpret_cast<__m256*>(v5),
      reinterpret_cast<__m256*>(v6), reinterpret_cast<__m256*>(v7));
}

}  // namespace intrin_utils
}  // namespace

template <Path path>
inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
  // The strides in params are given in bytes; scale by sizeof(float) to get
  // strides in float elements.
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  const std::int64_t dst_stride = params.dst_stride >> 2;
  const std::int64_t rhs_stride = params.rhs_stride >> 2;
  //
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  // AVX2 float block size = 8.
  const int end_row = std::min(params.dst_rows, params.last_row + 8);
  const int end_col = std::min(params.dst_cols, params.last_col + 8);
  //
  const float* adj_rhs_col_ptr =
      params.rhs_base_ptr - params.start_col * rhs_stride;
  float* adj_dst_col_ptr =
      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_ptr = params.bias;

  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
  const bool channel_dimension_is_col =
      params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;

  int col = params.start_col;
  // Loop over cols in blocks of 8 (the float block size); any incomplete
  // remainder block is handled after this loop.
  for (; col <= end_col - 8; col += 8) {
    __m256 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    for (int row = params.start_row; row < end_row; row += 8) {
      const int residual_rows = std::min(end_row - row, 8);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Initialize with bias.
      if (channel_dimension_is_col) {
        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
        }
      } else {
        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);

        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = initial_accum_data;
        }
      }

      const float* lhs_ptr = lhs_col_ptr;
      const float* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; ++d) {
        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
        // Load 8 RHS values, then use permute instructions to broadcast each
        // value to a register. _mm256_permute2f128_ps is slow on AMD.
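        // rhs0_3 / rhs4_7 hold rhs_ptr[0..3] / rhs_ptr[4..7] duplicated into
        // both 128-bit halves; _mm256_permute_ps with immediates 0x00, 0x55,
        // 0xAA, 0xFF (0, 85, 170, 255) then replicates element 0, 1, 2 or 3 of
        // each half, broadcasting one RHS value across all 8 lanes.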
        __m256 rhs0_3 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
        __m256 rhs4_7 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));

        const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
        accum_data_v[0] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_0, accum_data_v[0]);

        const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85);
        accum_data_v[1] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_1, accum_data_v[1]);

        const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170);
        accum_data_v[2] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_2, accum_data_v[2]);

        const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255);
        accum_data_v[3] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_3, accum_data_v[3]);

        const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0);
        accum_data_v[4] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_4, accum_data_v[4]);

        const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85);
        accum_data_v[5] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_5, accum_data_v[5]);

        const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170);
        accum_data_v[6] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_6, accum_data_v[6]);

        const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255);
        accum_data_v[7] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_7, accum_data_v[7]);

        lhs_ptr += 8;
        rhs_ptr += 8;
      }

      if (residual_rows == 8) {
        for (int j = 0; j < 8; ++j) {
          float* block_ptr = dst_ptr + j * dst_stride;
          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
          _mm256_storeu_ps(block_ptr, accum_data_v[j]);
        }
      } else {
        for (int j = 0; j < 8; ++j) {
          float* block_ptr = dst_ptr + j * dst_stride;
          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
          intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
                                                accum_data_v[j]);
        }
      }
    }  // End row-block loop.
  }  // End col-block loop.

  if (col < end_col) {
    // Remaining cols in [0, float block size).
    RUY_DCHECK_GE(end_col - col, 0);
    RUY_DCHECK_LT(end_col - col, 8);

    __m256 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
    const int residual_cols = std::min(end_col - col, 8);

    for (int row = params.start_row; row < end_row; row += 8) {
      const int residual_rows = std::min(end_row - row, 8);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Initialize with bias.
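      // Column-channel case: each of the 8 destination columns gets its own
      // scalar bias broadcast across the accumulator. Row-channel case: one
      // vector of 8 row biases initializes every column's accumulator.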
      if (channel_dimension_is_col) {
        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
        }
      } else {
        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);

        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = initial_accum_data;
        }
      }

      const float* lhs_ptr = lhs_col_ptr;
      const float* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; ++d) {
        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);

        __m256 rhs0_3 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
        __m256 rhs4_7 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));

        const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
        accum_data_v[0] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_0, accum_data_v[0]);

        const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85);
        accum_data_v[1] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_1, accum_data_v[1]);

        const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170);
        accum_data_v[2] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_2, accum_data_v[2]);

        const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255);
        accum_data_v[3] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_3, accum_data_v[3]);

        const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0);
        accum_data_v[4] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_4, accum_data_v[4]);

        const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85);
        accum_data_v[5] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_5, accum_data_v[5]);

        const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170);
        accum_data_v[6] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_6, accum_data_v[6]);

        const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255);
        accum_data_v[7] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_7, accum_data_v[7]);

        lhs_ptr += 8;
        rhs_ptr += 8;
      }

      for (int j = 0; j < residual_cols; ++j) {
        float* block_ptr = dst_ptr + j * dst_stride;
        accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
        accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
        intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
                                              accum_data_v[j]);
      }
    }  // End row-block loop.
  }  // End col-block terminal conditional.
}

template <Path path>
inline void KernelFloatAvxCommonSingleCol(
    const KernelParamsFloat<8, 8>& params) {
  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  // The strides in params are given in bytes; scale by sizeof(float) to get
  // strides in float elements.
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  //
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  // AVX2 float block size = 8.
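  // Rows are processed in blocks of 8; end_row is clamped to dst_rows so the
  // trailing partial block falls through to the residual-row path below.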
  const int end_row = std::min(params.dst_rows, params.last_row + 8);

  float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_col_ptr = params.bias;

  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);

  __m256 accum_data_v;

  const float* rhs_col_ptr = params.rhs_base_ptr;
  float* dst_col_ptr = adj_dst_col_ptr;

  int row = params.start_row;
  for (; row <= end_row - 8; row += 8) {
    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
    float* dst_ptr = dst_col_ptr + row;
    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

    // Initialize with bias.
    accum_data_v = _mm256_loadu_ps(bias_ptr);

    const float* lhs_ptr = lhs_col_ptr;
    const float* rhs_ptr = rhs_col_ptr;
    int d = 0;
    for (; d <= params.depth - 4; d += 4) {
      const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr);
      const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_0, dup_rhs_element_0,
                                                accum_data_v);
      const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]);
      const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_1, dup_rhs_element_1,
                                                accum_data_v);

      const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16);
      const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_2, dup_rhs_element_2,
                                                accum_data_v);
      const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]);
      const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_3, dup_rhs_element_3,
                                                accum_data_v);
      lhs_ptr += 32;  // Loaded 8 * 4 floats.
      rhs_ptr += 32;
    }
    for (; d < params.depth; ++d) {
      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
      const float* rhs_data = rhs_ptr;

      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
      accum_data_v =
          intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
      lhs_ptr += 8;
      rhs_ptr += 8;
    }

    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
    _mm256_storeu_ps(dst_ptr, accum_data_v);
  }  // End row-block loop.

  if (row < end_row) {
    const int residual_rows = end_row - row;
    RUY_CHECK_GE(residual_rows, 1);
    RUY_CHECK_LT(residual_rows, 8);

    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
    float* dst_ptr = dst_col_ptr + row;
    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

    // Initialize with bias.
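    // A full vector of 8 bias values is loaded even though only residual_rows
    // lanes are stored below.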
    accum_data_v = _mm256_loadu_ps(bias_ptr);

    const float* lhs_ptr = lhs_col_ptr;
    const float* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; ++d) {
      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
      const float* rhs_data = rhs_ptr;

      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
      accum_data_v =
          intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
      lhs_ptr += 8;
      rhs_ptr += 8;
    }

    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
    intrin_utils::mm256_n_storeu_ps<path>(dst_ptr, residual_rows, accum_data_v);
  }  // End handling of residual rows.
}
}  // namespace ruy
#endif  //  (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)

#endif  // RUY_RUY_KERNEL_X86_H_