/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_

// Depending on the build configuration, this header provides a custom kernel
// for Eigen tensor contractions (the small matrix multiplication kernel used
// to multiply together blocks of the original tensors).
//
// 1) --define tensorflow_mkldnn_contraction_kernel=1
//    Use the single-threaded mkldnn (DNNL) sgemm. The mkldnn kernels are
//    generated at runtime and use avx/avx2/fma/avx512 based on the CPU
//    features detected via CPUID (https://en.wikipedia.org/wiki/CPUID).
//
// If you use `tensor.contract(other_tensor)` in your code, you must include
// this header to get the benefit of the custom contraction kernel:
//
//   #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
//   #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
//   #endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// FixedPoint header must be included after Tensor.
// clang-format off
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
// clang-format on

#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
#include "dnnl.h"
#endif

#include "tensorflow/core/platform/dynamic_annotations.h"

namespace Eigen {
namespace internal {

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
// Returns `true` iff we can use custom contraction kernels. This is a runtime
// check that uses environment variables.
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE bool UseCustomContractionKernels();

// Packs a 2D block of a Tensor expression into a contiguous block of memory
// with col-major storage order. We do not have access to the underlying Tensor
// expression; we only have a DataMapper (TensorContractionInputMapper for
// tensor contractions, or blas_data_mapper for plain tensors) that provides a
// two-dimensional view into the Tensor expression.
//
// The default Eigen gemm_pack_rhs and gemm_pack_lhs pack blocks of tensor
// expressions into the packed format described in the "Anatomy of
// High-Performance Matrix Multiplication" paper (1).
// Eigen::internal::gebp_kernel relies on this packing format for efficient
// micro-panel multiplication.
//
// This simple packing can be used with any '?gemm' function from BLAS
// libraries that work with col-major matrices.
//
// (1) http://www.cs.utexas.edu/~flame/pubs/GotoTOMS_revision.pdf
//
// IMPORTANT: `gemm_pack_colmajor_block` always packs the block in column-major
// order; DataMapperStorageOrder specifies the storage order of the underlying
// Tensor expression.
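//
// For illustration only (a hypothetical sketch, not part of this header's
// API): given a col-major SubMapper `sub_mapper` pointing at the block origin
// and a scratch buffer `packed` with room for `rows * cols` scalars, the
// packer declared below would be invoked roughly as
//
//   gemm_pack_colmajor_block<float, Eigen::Index, SubMapper, ColMajor> pack;
//   pack(packed, sub_mapper, rows, cols);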
template <typename Scalar, typename IndexType, typename DataMapper,
          int DataMapperStorageOrder>
struct gemm_pack_colmajor_block;

// gemm_pack_colmajor_block for ColMajor storage order.
template <typename Scalar, typename IndexType, typename DataMapper>
struct gemm_pack_colmajor_block<Scalar, IndexType, DataMapper,
                                /*DataMapperStorageOrder*/ ColMajor> {
  typedef typename internal::packet_traits<Scalar>::type Packet;
  typedef typename DataMapper::LinearMapper LinearMapper;

  enum { PacketSize = internal::packet_traits<Scalar>::size };

  EIGEN_DONT_INLINE
  void operator()(Scalar* block, const DataMapper& data_mapper, IndexType rows,
                  IndexType cols) {
    const IndexType unrolled_rows = rows - 4 * PacketSize;
    const IndexType vectorized_rows = rows - PacketSize;

    for (IndexType col = 0; col < cols; ++col) {
      LinearMapper lm = data_mapper.getLinearMapper(0, col);

      IndexType row = 0;
      // Give the compiler a good chance to unroll the loop.
      for (; row <= unrolled_rows; row += 4 * PacketSize) {
        for (IndexType j = 0; j < 4; ++j) {
          const Packet p = lm.template loadPacket<Packet>(row + j * PacketSize);
          internal::pstoreu(block + j * PacketSize, p);
        }
        block += 4 * PacketSize;
      }
      // Process remaining rows with packets.
      for (; row <= vectorized_rows; row += PacketSize) {
        const Packet p = lm.template loadPacket<Packet>(row);
        internal::pstoreu(block, p);
        block += PacketSize;
      }
      // Finalize with coefficients.
      for (; row < rows; ++row) {
        *block = lm(row);
        ++block;
      }
    }
  }
};

#endif  // TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL

// Enabled by build option: "--define tensorflow_mkldnn_contraction_kernel=1"
#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)

template <typename Scalar, typename IndexType, typename OutputMapper,
          bool ConjugateLhs = false, bool ConjugateRhs = false>
struct dnnl_gemm_kernel;

// dnnl_gemm_kernel for floats, defined as a thin layer on top of dnnl_sgemm.
template <typename IndexType, typename OutputMapper, bool ConjugateLhs,
          bool ConjugateRhs>
struct dnnl_gemm_kernel</*Scalar*/ float, IndexType, OutputMapper, ConjugateLhs,
                        ConjugateRhs> {
  static_assert(!ConjugateLhs, "DNNL kernel doesn't support ConjugateLhs");
  static_assert(!ConjugateRhs, "DNNL kernel doesn't support ConjugateRhs");

  static constexpr int kComputeStrideFromBlockDimensions = -1;

  using LhsScalar = float;
  using RhsScalar = float;
  using ResScalar = float;

  EIGEN_DONT_INLINE
  void operator()(const OutputMapper& output, const LhsScalar* blockA,
                  const RhsScalar* blockB, const IndexType rows,
                  const IndexType depth, const IndexType cols, float alpha,
                  float beta, int ldA = kComputeStrideFromBlockDimensions,
                  int ldB = kComputeStrideFromBlockDimensions,
                  char transposeA = 'N', char transposeB = 'N') {
    static const int max_index = (std::numeric_limits<int>::max)();

    eigen_assert(max_index >= rows);
    eigen_assert(max_index >= cols);
    eigen_assert(max_index >= depth);
    eigen_assert(max_index >= output.stride());

    const int m = static_cast<int>(rows);
    const int n = static_cast<int>(cols);
    const int k = static_cast<int>(depth);

    ldA = ldA == kComputeStrideFromBlockDimensions ? m : ldA;
    ldB = ldB == kComputeStrideFromBlockDimensions ? k : ldB;
    const int ldC = static_cast<int>(output.stride());

    // DNNL takes row-major matrices. Our packed column-major matrices can be
    // viewed as a transposed row-major matrix, i.e.,
    //   C_colmajor = C_rowmajor^T = (A_rowmajor * B_rowmajor)^T
    //              = B_rowmajor^T * A_rowmajor^T
    //              = B_colmajor * A_colmajor
    // So we can just swap the input matrices A and B for DNNL.
    // TODO(penporn): Switch to row-major packing instead.
    dnnl_status_t st =
        dnnl_sgemm(transposeB, transposeA, n, m, k, alpha, blockB, ldB, blockA,
                   ldA, beta, const_cast<ResScalar*>(output.data()), ldC);
    eigen_assert(st == 0);

#if DYNAMIC_ANNOTATIONS_ENABLED == 1 || defined(MEMORY_SANITIZER)
    for (IndexType col = 0; col < cols; ++col) {
      ResScalar* row_base = &output(0, col);
      EIGEN_UNUSED_VARIABLE(row_base);  // Suppress unused variable error.
      TF_ANNOTATE_MEMORY_IS_INITIALIZED(row_base, sizeof(ResScalar) * rows);
    }
#endif

    // eigen_assert is a no-op in optimized mode, so we add these to avoid the
    // compiler's unused-variable errors.
    EIGEN_UNUSED_VARIABLE(max_index);
    EIGEN_UNUSED_VARIABLE(st);
  }
};

template <typename IndexType, typename OutputMapper, bool ConjugateLhs = false,
          bool ConjugateRhs = false>
struct mkldnn_gemm_s8u8s32_kernel {
  static_assert(!ConjugateLhs, "DNNL kernel doesn't support ConjugateLhs");
  static_assert(!ConjugateRhs, "DNNL kernel doesn't support ConjugateRhs");

  static constexpr int kComputeStrideFromBlockDimensions = -1;

  using LhsScalar = Eigen::QInt8;
  using RhsScalar = Eigen::QUInt8;
  using ResScalar = Eigen::QInt32;

  EIGEN_DONT_INLINE
  void operator()(const OutputMapper& output, const LhsScalar* blockA,
                  const RhsScalar* blockB, const IndexType rows,
                  const IndexType depth, const IndexType cols, float alpha,
                  float beta, int ldA = kComputeStrideFromBlockDimensions,
                  int ldB = kComputeStrideFromBlockDimensions,
                  char transposeA = 'N', char transposeB = 'N') {
    static const int max_index = (std::numeric_limits<int>::max)();

    eigen_assert(max_index >= rows);
    eigen_assert(max_index >= cols);
    eigen_assert(max_index >= depth);
    eigen_assert(max_index >= output.stride());

    const int m = static_cast<int>(rows);
    const int n = static_cast<int>(cols);
    const int k = static_cast<int>(depth);

    ldA = ldA == kComputeStrideFromBlockDimensions ? m : ldA;
    ldB = ldB == kComputeStrideFromBlockDimensions ? k : ldB;
    const int ldC = static_cast<int>(output.stride());

    // Currently we support only symmetric quantization with zero point at 0.
    const int8_t ao = 0;
    const int8_t bo = 0;

    // Don't add any offset to the result C.
    const char offsetc = 'F';
    const int32_t co = 0;

    const auto* A = reinterpret_cast<const int8_t*>(blockA);
    const auto* B = reinterpret_cast<const uint8_t*>(blockB);
    auto* C = reinterpret_cast<int32_t*>(const_cast<ResScalar*>(output.data()));

    // DNNL takes row-major matrices. Our packed column-major matrices can be
    // viewed as a transposed row-major matrix, i.e.,
    //   C_colmajor = C_rowmajor^T = (A_rowmajor * B_rowmajor)^T
    //              = B_rowmajor^T * A_rowmajor^T
    //              = B_colmajor * A_colmajor
    // So we can just swap the input matrices A and B for DNNL.
    // TODO(penporn): Switch to row-major packing instead.
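    // Concretely, the call below requests the row-major product
    //   C'(n x m) = alpha * B'(n x k) * A'(k x m) + beta * C',
    // which, reinterpreted as column-major data, is exactly
    //   C(m x n) = alpha * A(m x k) * B(k x n) + beta * C.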
    dnnl_status_t st = dnnl_gemm_u8s8s32(transposeB, transposeA, offsetc,  //
                                         n, m, k,                         //
                                         alpha,                           //
                                         B, ldB, bo,                      //
                                         A, ldA, ao,                      //
                                         beta,                            //
                                         C, ldC, &co);
    eigen_assert(st == 0);

#if DYNAMIC_ANNOTATIONS_ENABLED == 1 || defined(MEMORY_SANITIZER)
    for (IndexType col = 0; col < cols; ++col) {
      ResScalar* row_base = &output(0, col);
      EIGEN_UNUSED_VARIABLE(row_base);  // Suppress unused variable error.
      TF_ANNOTATE_MEMORY_IS_INITIALIZED(row_base, sizeof(ResScalar) * rows);
    }
#endif

    // eigen_assert is a no-op in optimized mode, so we add these to avoid the
    // compiler's unused-variable errors.
    EIGEN_UNUSED_VARIABLE(max_index);
    EIGEN_UNUSED_VARIABLE(st);
  }
};

// For mkldnn_sgemm, having the right dimensions (especially for small
// matrices) is more important than fitting the whole working set into the
// L1/L2 caches.
// TODO(ezhulenev): Do better heuristics.
template <typename StorageIndex, int sharding_type>
class TensorContractionBlocking<float, float, float, StorageIndex,
                                sharding_type> {
  // For now mkldnn has only mkldnn_sgemm (gemm for floats).
  using Scalar = float;

  // Adjust the block sizes to work well with the mkldnn kernels.

  // Multipliers for the default choice of block sizes along the M and N
  // dimensions.
  // TODO(ezhulenev): Explore if this can work in general (kScaleM=2.0 worked
  // well in some models).
  static constexpr float kScaleM = 1.5;
  static constexpr float kScaleN = 1.0;

  // Mkldnn Avx/Avx2/Avx512 unroll factors are: 8/16/48.
  static constexpr StorageIndex kUnrollM = 48;

  // Mkldnn Avx/Avx2/Avx512 unroll factors are: 6/6/8.
  static constexpr StorageIndex kUnrollN = 24;

 public:
  TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n,
                            StorageIndex num_threads = 1)
      : kc_(k), mc_(m), nc_(n) {
    // 1. Compute block sizes using default Eigen heuristics.
    if (sharding_type == ShardByCol) {
      computeProductBlockingSizes<Scalar, Scalar, 1>(kc_, mc_, nc_,
                                                     num_threads);
    } else {
      computeProductBlockingSizes<Scalar, Scalar, 1>(kc_, nc_, mc_,
                                                     num_threads);
    }

    // If the dimensions do not pass basic sanity checks, return immediately.
    if (kc_ <= 0 || mc_ <= 0 || nc_ <= 0) return;

    // If we are using the default Eigen gebp kernel, there is no need to
    // adjust the block sizes for DNNL.
    if (!UseCustomContractionKernels()) return;

    // 2. Refine them to work well with mkldnn sgemm.
    mc_ = (std::min)(
        m, Eigen::divup(static_cast<StorageIndex>(mc_ * kScaleM), kUnrollM) *
               kUnrollM);
    nc_ = (std::min)(
        n, Eigen::divup(static_cast<StorageIndex>(nc_ * kScaleN), kUnrollN) *
               kUnrollN);

    // We split the K dimension into roughly equal slices.
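    // For example (illustrative numbers only): with k = 1000 and an Eigen
    // estimate of kc_ = 256 we get divup(1000, 256) = 4 slices; with
    // packet_size = 8 the per-slice block is divup(250, 8) * 8 = 256, so
    // kc_ = min(1000, 256) = 256 and the last slice ends up slightly smaller.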
    StorageIndex target_k_slices =
        (std::max)(StorageIndex(1), Eigen::divup(k, kc_));
    StorageIndex packet_size = internal::packet_traits<Scalar>::size;
    if (packet_size < 8) packet_size = 8;
    StorageIndex target_bk =
        Eigen::divup(k / target_k_slices, packet_size) * packet_size;
    kc_ = (std::min)(k, target_bk);
  }

  EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }
  EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; }
  EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; }

 private:
  StorageIndex kc_;
  StorageIndex mc_;
  StorageIndex nc_;
};

template <typename StorageIndex, int sharding_type>
class TensorContractionBlocking<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
                                StorageIndex, sharding_type> {
  // TODO(ezhulenev): Define proper gebp_traits in Eigen for quantized types?

  // Default Eigen block heuristics for `QInt8xQUInt8 -> QInt32` are wrong,
  // mostly because gebp_traits are not correctly defined. But we know that we
  // are going to use s8u8s32 gemm from DNNL, so we use float heuristics and
  // adjust them to work well with DNNL.
  using LhsScalar = Eigen::QInt8;
  using RhsScalar = Eigen::QUInt8;
  using ResScalar = Eigen::QInt32;

  // Multipliers for the default choice of block sizes along the M, N and K
  // dimensions.
  static constexpr float kScaleM = 1.5;
  static constexpr float kScaleN = 1.5;
  static constexpr float kScaleK = 1.5;

 public:
  TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n,
                            StorageIndex num_threads = 1)
      : kc_(k), mc_(m), nc_(n) {
    // Each dimension is a multiple of 32 (fits into __m256i).
    mc_ = (std::min)(m, static_cast<StorageIndex>(192));
    nc_ = (std::min)(n, static_cast<StorageIndex>(288));
    kc_ = (std::min)(k, static_cast<StorageIndex>(320));
  }

  EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }
  EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; }
  EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; }

 private:
  StorageIndex kc_;
  StorageIndex mc_;
  StorageIndex nc_;
};

// If the Lhs or Rhs Tensor expressions are already evaluated and have access
// to raw data, we can skip the packing step, set up pointers and a stride to
// the underlying memory buffer, and pass them directly to the GEMM kernel.
template <typename Scalar, typename StorageIndex>
struct ColMajorBlock {
  bool is_direct_access;

  // Valid iff `is_direct_access == false`
  Scalar* packed_data;

  // Valid iff `is_direct_access == true`
  Scalar* raw_data;
  StorageIndex stride;
  char transpose;
};

template <typename DataMapper>
struct DirectColMajorAccess {
  enum { value = false };

  template <typename Scalar, typename StorageIndex>
  static bool block(const typename DataMapper::SubMapper& data_mapper,
                    const StorageIndex rows, const StorageIndex cols,
                    const StorageIndex num_kernels,
                    ColMajorBlock<Scalar, StorageIndex>* block) {
    eigen_assert(false && "Not implemented");
    return false;
  }
};

// If we have access to the raw memory of the contraction input, we can safely
// skip packing if:
//   (1) Packing is a no-op.
//   (2) The packed block will be used just once.
//
// If a packed block is used many times, it's more efficient to pack it into a
// contiguous block of memory to reduce pressure on the TLB.
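//
// Concretely, in the macro below: packing is a no-op when the input's stride
// equals the number of rows in the block, and "used just once" corresponds to
// `num_kernels == 1` (with an extra allowance for `num_kernels == 2` when the
// addressable memory is under 256 KB).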
//
// TODO(ezhulenev): Add support for more tensor expressions that matter.
#define REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_EXPR) \
  template <typename Scalar, typename StorageIndex, int Side, typename Device, \
            typename nocontract_t, typename contract_t, int packet_size, \
            int Alignment> \
  struct DirectColMajorAccess<TensorContractionInputMapper< \
      Scalar, StorageIndex, Side, TensorEvaluator<TENSOR_EXPR, Device>, \
      nocontract_t, contract_t, packet_size, /*inner_dim_contiguous=*/true, \
      /*inner_dim_reordered=*/false, Alignment>> { \
    enum { value = true }; \
\
    using DataMapper = TensorContractionInputMapper< \
        Scalar, StorageIndex, Side, TensorEvaluator<TENSOR_EXPR, Device>, \
        nocontract_t, contract_t, packet_size, /*inner_dim_contiguous=*/true, \
        /*inner_dim_reordered=*/false, Alignment>; \
\
    static bool block(const typename DataMapper::SubMapper& data_mapper, \
                      const StorageIndex rows, const StorageIndex cols, \
                      const StorageIndex num_kernels, \
                      ColMajorBlock<Scalar, StorageIndex>* block) { \
      static_assert(DataMapper::DirectOffsets == true, \
                    "DataMapper must support direct offsets"); \
\
      const StorageIndex vert_offset = data_mapper.vert_offset(); \
      const StorageIndex horiz_offset = data_mapper.horiz_offset(); \
      const StorageIndex stride = \
          Side == Lhs ? data_mapper.base_mapper().stride() \
                      : data_mapper.base_mapper().nocontract_strides()[0]; \
      const Scalar* data = data_mapper.base_mapper().tensor().data(); \
      data = Side == Lhs ? data : data + vert_offset + horiz_offset * stride; \
\
      const bool is_no_op_packing = stride == rows; \
      const StorageIndex addressable_mem = (stride * cols * sizeof(Scalar)); \
      const bool use_direct_access = \
          is_no_op_packing || num_kernels == 1 /* used once */ || \
          ((num_kernels == 2) && \
           (addressable_mem < (256 << 10) /* 256 kb */)); \
\
      if (use_direct_access) { \
        block->is_direct_access = true; \
        block->raw_data = const_cast<Scalar*>(data); \
        block->stride = stride; \
        block->transpose = 'N'; \
        return true; \
      } \
      return false; \
    } \
  }

#define SIMPLE_TENSOR const Tensor<Scalar, 2, Eigen::ColMajor, StorageIndex>

#define TENSOR_MAP_ROWMAJOR \
  const TensorMap<Tensor<const Scalar, 2, Eigen::RowMajor, StorageIndex>, \
                  Eigen::Aligned>

#define TENSOR_MAP_COLMAJOR \
  const TensorMap<Tensor<const Scalar, 2, Eigen::ColMajor, StorageIndex>, \
                  Eigen::Aligned>

#define TENSOR_MAP_CONST_ROWMAJOR \
  const TensorMap<Tensor<Scalar, 2, Eigen::RowMajor, StorageIndex>, \
                  Eigen::Aligned>

#define TENSOR_MAP_CONST_COLMAJOR \
  const TensorMap<Tensor<Scalar, 2, Eigen::ColMajor, StorageIndex>, \
                  Eigen::Aligned>

// This is the reshaped convolution filter from `eigen_spatial_convolutions.h`.
#define TENSOR_RESHAPE \
  const TensorReshapingOp< \
      const Eigen::DSizes<StorageIndex, 2>, \
      const TensorMap<Tensor<const Scalar, 4, Eigen::RowMajor, StorageIndex>, \
                      Eigen::Aligned>>

REGISTER_DIRECT_COL_MAJOR_ACCESS(SIMPLE_TENSOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_ROWMAJOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_COLMAJOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_CONST_ROWMAJOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_CONST_COLMAJOR);
REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_RESHAPE);

#undef SIMPLE_TENSOR
#undef TENSOR_MAP_ROWMAJOR
#undef TENSOR_MAP_COLMAJOR
#undef TENSOR_MAP_CONST_ROWMAJOR
#undef TENSOR_MAP_CONST_COLMAJOR
#undef TENSOR_RESHAPE
#undef REGISTER_DIRECT_COL_MAJOR_ACCESS

template <typename ResScalar, typename LhsScalar, typename RhsScalar,
          typename StorageIndex, typename OutputMapper>
struct GemmKernelProvider {
  enum { Defined = 0 };
  using GemmKernel = void;
};

template <typename StorageIndex, typename OutputMapper>
struct GemmKernelProvider<float, float, float, StorageIndex, OutputMapper> {
  enum { Defined = 1 };
  using GemmKernel = dnnl_gemm_kernel<float, StorageIndex, OutputMapper>;
};

template <typename StorageIndex, typename OutputMapper>
struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
                          StorageIndex, OutputMapper> {
  enum { Defined = 1 };
  using GemmKernel = mkldnn_gemm_s8u8s32_kernel<StorageIndex, OutputMapper>;
};

// NOTE: 'std::enable_if' doesn't work for template specializations. See
// "default template argument in a class template partial specialization".

// Tensor contraction kernel that can fall back on the Eigen gebp_kernel at
// runtime.
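// The fallback decision is made at runtime: if UseCustomContractionKernels()
// returns false, blocks are packed with Eigen's gemm_pack_lhs/gemm_pack_rhs
// and multiplied with gebp_kernel; otherwise blocks are packed with
// gemm_pack_colmajor_block (or accessed directly when possible) and multiplied
// with the DNNL-backed GemmKernel.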
#define REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK( \
    RES_SCALAR, LHS_SCALAR, RHS_SCALAR) \
\
  template <typename StorageIndex, typename OutputMapper, typename LhsMapper, \
            typename RhsMapper> \
  struct TensorContractionKernel<RES_SCALAR, LHS_SCALAR, RHS_SCALAR, \
                                 StorageIndex, OutputMapper, LhsMapper, \
                                 RhsMapper> { \
    TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n, \
                            StorageIndex bm, StorageIndex bk, StorageIndex bn) \
        : m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {} \
\
    enum { HasBeta = true }; \
\
    using ResScalar = RES_SCALAR; \
    using LhsScalar = LHS_SCALAR; \
    using RhsScalar = RHS_SCALAR; \
\
    using Traits = typename internal::gebp_traits<LhsScalar, RhsScalar>; \
\
    using LhsBlock = ColMajorBlock<LhsScalar, StorageIndex>; \
    using RhsBlock = ColMajorBlock<RhsScalar, StorageIndex>; \
\
    using DirectLhsAccess = DirectColMajorAccess<LhsMapper>; \
    using DirectRhsAccess = DirectColMajorAccess<RhsMapper>; \
\
    /* Packed Lhs/Rhs block memory allocator. */ \
    typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar> \
        BlockMemAllocator; \
    typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle; \
\
    using LhsPacker = \
        gemm_pack_colmajor_block<LhsScalar, StorageIndex, \
                                 typename LhsMapper::SubMapper, ColMajor>; \
    using RhsPacker = \
        gemm_pack_colmajor_block<RhsScalar, StorageIndex, \
                                 typename RhsMapper::SubMapper, ColMajor>; \
\
    using GemmKernelProviderType = \
        GemmKernelProvider<ResScalar, LhsScalar, RhsScalar, StorageIndex, \
                           OutputMapper>; \
    static_assert( \
        GemmKernelProviderType::Defined, \
        "Custom GEMM kernel is not registered for given scalar types"); \
    using GemmKernel = typename GemmKernelProviderType::GemmKernel; \
\
    /* Fall back on the default Eigen pack and GEBP kernel if the custom */ \
    /* contraction kernels are disabled at runtime. */ \
    using EigenLhsPacker = \
        gemm_pack_lhs<LhsScalar, StorageIndex, typename LhsMapper::SubMapper, \
                      Traits::mr, Traits::LhsProgress, \
                      typename Traits::LhsPacket4Packing, ColMajor>; \
    using EigenRhsPacker = \
        gemm_pack_rhs<RhsScalar, StorageIndex, typename RhsMapper::SubMapper, \
                      Traits::nr, ColMajor>; \
    using GebpKernel = \
        gebp_kernel<LhsScalar, RhsScalar, StorageIndex, OutputMapper, \
                    Traits::mr, Traits::nr, /*ConjugateLhs*/ false, \
                    /*ConjugateRhs*/ false>; \
\
    template <typename Device> \
    EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block, \
                                              RhsBlock* rhs_block) { \
      return BlockMemAllocator::allocate( \
          d, bm, bk, bn, &lhs_block->packed_data, &rhs_block->packed_data); \
    } \
\
    template <typename Device> \
    EIGEN_DEVICE_FUNC BlockMemHandle \
    allocateSlices(Device& d, const int num_lhs, const int num_rhs, \
                   const int num_slices, std::vector<LhsBlock>* lhs_blocks, \
                   std::vector<RhsBlock>* rhs_blocks) { \
      eigen_assert(num_slices > 0); \
      std::vector<std::vector<LhsScalar*>> lhs_mem(num_slices); \
      std::vector<std::vector<RhsScalar*>> rhs_mem(num_slices); \
\
      BlockMemHandle block_mem = BlockMemAllocator::allocateSlices( \
          d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_mem.data(), \
          rhs_mem.data()); \
\
      for (Index x = 0; x < num_slices; x++) { \
        if (num_lhs > 0) lhs_blocks[x].resize(num_lhs); \
        for (Index m = 0; m < num_lhs; m++) { \
          lhs_blocks[x][m].packed_data = lhs_mem[x][m]; \
        } \
        if (num_rhs > 0) rhs_blocks[x].resize(num_rhs); \
        for (Index n = 0; n < num_rhs; n++) { \
          rhs_blocks[x][n].packed_data = rhs_mem[x][n]; \
        } \
      } \
\
      return block_mem; \
    } \
\
    template <typename Device> \
    EIGEN_DEVICE_FUNC static void deallocate(Device& d, \
                                             BlockMemHandle handle) { \
      BlockMemAllocator::deallocate(d, handle); \
    } \
\
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs( \
        LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper, \
        const StorageIndex depth, const StorageIndex rows) { \
      if (UseCustomContractionKernels()) { \
        const bool is_direct_access = \
            DirectLhsAccess::value && \
            DirectLhsAccess::block(data_mapper, rows, depth, \
                                   bn > 0 ? divup(n, bn) : 0, lhsBlock); \
\
        if (!is_direct_access) { \
          lhsBlock->is_direct_access = false; \
          LhsPacker()(lhsBlock->packed_data, data_mapper, rows, depth); \
        } \
      } else { \
        lhsBlock->is_direct_access = false; \
        EigenLhsPacker()(lhsBlock->packed_data, data_mapper, depth, rows, \
                         /*stride*/ 0, /*offset*/ 0); \
      } \
    } \
\
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs( \
        RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper, \
        const StorageIndex depth, const StorageIndex cols) { \
      if (UseCustomContractionKernels()) { \
        const bool is_direct_access = \
            DirectRhsAccess::value && \
            DirectRhsAccess::block(data_mapper, depth, cols, \
                                   bm > 0 ? divup(m, bm) : 0, rhsBlock); \
\
        if (!is_direct_access) { \
          rhsBlock->is_direct_access = false; \
          RhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols); \
        } \
      } else { \
        rhsBlock->is_direct_access = false; \
        EigenRhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols); \
      } \
    } \
\
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke( \
        const OutputMapper& output_mapper, const LhsBlock& lhsBlock, \
        const RhsBlock& rhsBlock, const StorageIndex rows, \
        const StorageIndex depth, const StorageIndex cols, const float alpha, \
        const float beta) { \
      if (UseCustomContractionKernels()) { \
        if ((DirectLhsAccess::value && lhsBlock.is_direct_access) && \
            (DirectRhsAccess::value && rhsBlock.is_direct_access)) { \
          GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.raw_data, \
                       rows, depth, cols, alpha, beta, \
                       /*ldA=*/lhsBlock.stride, /*ldB=*/rhsBlock.stride, \
                       /*transposeA=*/lhsBlock.transpose, \
                       /*transposeB=*/rhsBlock.transpose); \
\
        } else if (DirectLhsAccess::value && lhsBlock.is_direct_access) { \
          GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.packed_data, \
                       rows, depth, cols, alpha, beta, \
                       /*ldA=*/lhsBlock.stride, \
                       /*ldB=*/GemmKernel::kComputeStrideFromBlockDimensions, \
                       /*transposeA=*/lhsBlock.transpose, /*transposeB=*/'N'); \
\
        } else if (DirectRhsAccess::value && rhsBlock.is_direct_access) { \
          GemmKernel()(output_mapper, lhsBlock.packed_data, rhsBlock.raw_data, \
                       rows, depth, cols, alpha, beta, \
                       /*ldA=*/GemmKernel::kComputeStrideFromBlockDimensions, \
                       /*ldB=*/rhsBlock.stride, /*transposeA=*/'N', \
                       /*transposeB=*/rhsBlock.transpose); \
\
        } else { \
          GemmKernel()(output_mapper, lhsBlock.packed_data, \
                       rhsBlock.packed_data, rows, depth, cols, alpha, beta); \
        } \
      } else { \
        /* The gebp kernel does not support beta, so we have to clear the */ \
        /* output memory in the output mapper manually. */ \
        /* WARNING(ezhulenev): This gets optimized into a memset in a */ \
        /* loop and could be much slower for small matrices. Currently */ \
        /* this code path is used only for testing, and performance does */ \
        /* not matter. */ \
        if (beta == 0.0) { \
          for (StorageIndex col = 0; col < cols; ++col) { \
            ResScalar* output_base = &output_mapper(0, col); \
            typedef Array<ResScalar, Dynamic, 1> OutputRow; \
            typedef Map<OutputRow, 0, InnerStride<1>> OutputRowMap; \
            OutputRowMap(output_base, rows).setZero(); \
          } \
        } \
\
        GebpKernel()( \
            output_mapper, lhsBlock.packed_data, rhsBlock.packed_data, rows, \
            depth, cols, alpha, \
            /*strideA*/ GemmKernel::kComputeStrideFromBlockDimensions, \
            /*strideB*/ GemmKernel::kComputeStrideFromBlockDimensions, \
            /*offsetA*/ 0, /*offsetB*/ 0); \
      } \
    } \
\
   private: \
    /* These are the dimensions of the original Tensors and the selected */ \
    /* block sizes. The actual block sizes passed to the functions above */ \
    /* might be smaller because of partial blocks at the end. */ \
    const StorageIndex m; \
    const StorageIndex k; \
    const StorageIndex n; \
    const StorageIndex bm; \
    const StorageIndex bk; \
    const StorageIndex bn; \
  }

// Tensor contraction kernel that does not fall back on Eigen. Currently not
// all data types are supported by Eigen data packing and the default
// gebp_kernel.
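// Note: unlike the fallback variant above, this kernel never consults
// UseCustomContractionKernels(); it always packs with gemm_pack_colmajor_block
// (or uses direct access) and always dispatches to the registered GemmKernel,
// so the scalar types must have a GemmKernelProvider specialization (enforced
// by the static_assert inside the macro).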
#define REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK(RES_SCALAR, LHS_SCALAR, \
                                                       RHS_SCALAR) \
\
  template <typename StorageIndex, typename OutputMapper, typename LhsMapper, \
            typename RhsMapper> \
  struct TensorContractionKernel<RES_SCALAR, LHS_SCALAR, RHS_SCALAR, \
                                 StorageIndex, OutputMapper, LhsMapper, \
                                 RhsMapper> { \
    TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n, \
                            StorageIndex bm, StorageIndex bk, StorageIndex bn) \
        : m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {} \
\
    enum { HasBeta = true }; \
\
    using ResScalar = RES_SCALAR; \
    using LhsScalar = LHS_SCALAR; \
    using RhsScalar = RHS_SCALAR; \
\
    using Traits = typename internal::gebp_traits<LhsScalar, RhsScalar>; \
\
    using LhsBlock = ColMajorBlock<LhsScalar, StorageIndex>; \
    using RhsBlock = ColMajorBlock<RhsScalar, StorageIndex>; \
\
    using DirectLhsAccess = DirectColMajorAccess<LhsMapper>; \
    using DirectRhsAccess = DirectColMajorAccess<RhsMapper>; \
\
    /* Packed Lhs/Rhs block memory allocator. */ \
    typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar> \
        BlockMemAllocator; \
    typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle; \
\
    using LhsPacker = \
        gemm_pack_colmajor_block<LhsScalar, StorageIndex, \
                                 typename LhsMapper::SubMapper, ColMajor>; \
    using RhsPacker = \
        gemm_pack_colmajor_block<RhsScalar, StorageIndex, \
                                 typename RhsMapper::SubMapper, ColMajor>; \
\
    using GemmKernelProviderType = \
        GemmKernelProvider<ResScalar, LhsScalar, RhsScalar, StorageIndex, \
                           OutputMapper>; \
    static_assert( \
        GemmKernelProviderType::Defined, \
        "Custom GEMM kernel is not registered for given scalar types"); \
    using GemmKernel = typename GemmKernelProviderType::GemmKernel; \
\
    template <typename Device> \
    EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block, \
                                              RhsBlock* rhs_block) { \
      return BlockMemAllocator::allocate( \
          d, bm, bk, bn, &lhs_block->packed_data, &rhs_block->packed_data); \
    } \
\
    template <typename Device> \
    EIGEN_DEVICE_FUNC BlockMemHandle \
    allocateSlices(Device& d, const int num_lhs, const int num_rhs, \
                   const int num_slices, std::vector<LhsBlock>* lhs_blocks, \
                   std::vector<RhsBlock>* rhs_blocks) { \
      eigen_assert(num_slices > 0); \
      std::vector<std::vector<LhsScalar*>> lhs_mem(num_slices); \
      std::vector<std::vector<RhsScalar*>> rhs_mem(num_slices); \
\
      BlockMemHandle block_mem = BlockMemAllocator::allocateSlices( \
          d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_mem.data(), \
          rhs_mem.data()); \
\
      for (Index x = 0; x < num_slices; x++) { \
        if (num_lhs > 0) lhs_blocks[x].resize(num_lhs); \
        for (Index m = 0; m < num_lhs; m++) { \
          lhs_blocks[x][m].packed_data = lhs_mem[x][m]; \
        } \
        if (num_rhs > 0) rhs_blocks[x].resize(num_rhs); \
        for (Index n = 0; n < num_rhs; n++) { \
          rhs_blocks[x][n].packed_data = rhs_mem[x][n]; \
        } \
      } \
\
      return block_mem; \
    } \
\
    template <typename Device> \
    EIGEN_DEVICE_FUNC static void deallocate(Device& d, \
                                             BlockMemHandle handle) { \
      BlockMemAllocator::deallocate(d, handle); \
    } \
\
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs( \
        LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper, \
        const StorageIndex depth, const StorageIndex rows) { \
      const bool is_direct_access = \
          DirectLhsAccess::value && \
          DirectLhsAccess::block(data_mapper, rows, depth, \
                                 bn > 0 ? divup(n, bn) : 0, lhsBlock); \
\
      if (!is_direct_access) { \
        lhsBlock->is_direct_access = false; \
        LhsPacker()(lhsBlock->packed_data, data_mapper, rows, depth); \
      } \
    } \
\
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs( \
        RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper, \
        const StorageIndex depth, const StorageIndex cols) { \
      const bool is_direct_access = \
          DirectRhsAccess::value && \
          DirectRhsAccess::block(data_mapper, depth, cols, \
                                 bm > 0 ? divup(m, bm) : 0, rhsBlock); \
\
      if (!is_direct_access) { \
        rhsBlock->is_direct_access = false; \
        RhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols); \
      } \
    } \
\
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke( \
        const OutputMapper& output_mapper, const LhsBlock& lhsBlock, \
        const RhsBlock& rhsBlock, const StorageIndex rows, \
        const StorageIndex depth, const StorageIndex cols, const float alpha, \
        const float beta) { \
      if ((DirectLhsAccess::value && lhsBlock.is_direct_access) && \
          (DirectRhsAccess::value && rhsBlock.is_direct_access)) { \
        GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.raw_data, \
                     rows, depth, cols, alpha, beta, /*ldA=*/lhsBlock.stride, \
                     /*ldB=*/rhsBlock.stride, \
                     /*transposeA=*/lhsBlock.transpose, \
                     /*transposeB=*/rhsBlock.transpose); \
\
      } else if (DirectLhsAccess::value && lhsBlock.is_direct_access) { \
        GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.packed_data, \
                     rows, depth, cols, alpha, beta, /*ldA=*/lhsBlock.stride, \
                     /*ldB=*/GemmKernel::kComputeStrideFromBlockDimensions, \
                     /*transposeA=*/lhsBlock.transpose, /*transposeB=*/'N'); \
\
      } else if (DirectRhsAccess::value && rhsBlock.is_direct_access) { \
        GemmKernel()(output_mapper, lhsBlock.packed_data, rhsBlock.raw_data, \
                     rows, depth, cols, alpha, beta, \
                     /*ldA=*/GemmKernel::kComputeStrideFromBlockDimensions, \
                     /*ldB=*/rhsBlock.stride, /*transposeA=*/'N', \
                     /*transposeB=*/rhsBlock.transpose); \
\
      } else { \
        GemmKernel()(output_mapper, lhsBlock.packed_data, \
                     rhsBlock.packed_data, rows, depth, cols, alpha, beta); \
      } \
    } \
\
   private: \
    /* These are the dimensions of the original Tensors and the selected */ \
    /* block sizes. The actual block sizes passed to the functions above */ \
    /* might be smaller because of partial blocks at the end. */ \
    const StorageIndex m; \
    const StorageIndex k; \
    const StorageIndex n; \
    const StorageIndex bm; \
    const StorageIndex bk; \
    const StorageIndex bn; \
  }

REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK(float, float, float);
REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK(Eigen::QInt32, Eigen::QInt8,
                                               Eigen::QUInt8);

#undef REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK
#undef REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK

#endif  // defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)

}  // namespace internal
}  // namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_