• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_
17 #define TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_
18 
// Depending on the build configuration, this header provides a custom kernel
// for Eigen tensor contractions (a small matrix multiplication kernel used to
// multiply together blocks of the original tensors).
22 //
23 // 1) --define tensorflow_mkldnn_contraction_kernel=1
24 //    Use Mkldnn single threaded sgemm. The mkldnn kernels are generated at
25 //    runtime and use avx/avx2/fma/avx512 based on cpu status registers
26 //    (https://en.wikipedia.org/wiki/CPUID).
27 //
28 // If you use `tensor.contract(other_tensor)` in your code, you must include
29 // this header to get the benefit of custom contraction kernel:
30 //
31 //   #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
32 //   #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
33 //   #endif
34 
35 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
36 
37 // FixedPoint header must be included after Tensor.
38 // clang-format off
39 #include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
40 // clang-format on
41 
42 #if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
43 #include "dnnl.h"
44 #endif
45 
46 #include "tensorflow/core/platform/dynamic_annotations.h"
47 
48 namespace Eigen {
49 namespace internal {
50 
51 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
// Returns `true` iff custom contraction kernels can be used. This is a runtime
// check that consults environment variables. The function is only declared in
// this header; when it returns `false`, callers are expected to fall back to
// the default Eigen packing and gebp kernel.
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE bool UseCustomContractionKernels();
55 
// Pack a 2D block of a Tensor expression into a contiguous block of memory
// with col-major storage order. We do not have access to the underlying Tensor
// expression; we only have a DataMapper (TensorContractionInputMapper for
// tensor contractions, or blas_data_mapper for plain tensors) that provides a
// two-dimensional view into the Tensor expression.
//
// Default Eigen gemm_pack_rhs and gemm_pack_lhs pack blocks of tensor
// expressions into the packed format described in the "Anatomy of
// High-Performance Matrix Multiplication" paper (1). Eigen::internal::gebp_kernel
// relies on this packing format for efficient micro-panel multiplication.
//
// This simple packing can be used with any '?gemm' function from BLAS
// libraries that works with col-major matrices.
//
// (1) http://www.cs.utexas.edu/~flame/pubs/GotoTOMS_revision.pdf
//
// IMPORTANT: `gemm_pack_colmajor_block` always packs the block in column major
// order; DataMapperStorageOrder specifies the storage order of the underlying
// Tensor expression. The primary template is only declared here; behavior is
// provided by specializations (a ColMajor one follows).
template <typename Scalar, typename IndexType, typename DataMapper,
          int DataMapperStorageOrder>
struct gemm_pack_colmajor_block;
78 
79 // gemm_pack_colmajor_block for ColMajor storage order.
80 template <typename Scalar, typename IndexType, typename DataMapper>
81 struct gemm_pack_colmajor_block<Scalar, IndexType, DataMapper,
82                                 /*DataMapperStorageOrder*/ ColMajor> {
83   typedef typename internal::packet_traits<Scalar>::type Packet;
84   typedef typename DataMapper::LinearMapper LinearMapper;
85 
86   enum { PacketSize = internal::packet_traits<Scalar>::size };
87 
88   EIGEN_DONT_INLINE
89   void operator()(Scalar* block, const DataMapper& data_mapper, IndexType rows,
90                   IndexType cols) {
91     const IndexType unrolled_rows = rows - 4 * PacketSize;
92     const IndexType vectorized_rows = rows - PacketSize;
93 
94     for (IndexType col = 0; col < cols; ++col) {
95       LinearMapper lm = data_mapper.getLinearMapper(0, col);
96 
97       IndexType row = 0;
98       // Give compiler a strong possibility to unroll the loop.
99       for (; row <= unrolled_rows; row += 4 * PacketSize) {
100         for (IndexType j = 0; j < 4; ++j) {
101           const Packet p = lm.template loadPacket<Packet>(row + j * PacketSize);
102           internal::pstoreu(block + j * PacketSize, p);
103         }
104         block += 4 * PacketSize;
105       }
106       // Process remaining rows with packets.
107       for (; row <= vectorized_rows; row += PacketSize) {
108         const Packet p = lm.template loadPacket<Packet>(row);
109         internal::pstoreu(block, p);
110         block += PacketSize;
111       }
112       // Finalize with coefficients.
113       for (; row < rows; ++row) {
114         *block = lm(row);
115         ++block;
116       }
117     }
118   }
119 };
120 
121 #endif  // TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL
122 
123 // Enabled by build option: "--define tensorflow_mkldnn_contraction_kernel=1"
124 #if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
125 
126 template <typename Scalar, typename IndexType, typename OutputMapper,
127           bool ConjugateLhs = false, bool ConjugateRhs = false>
128 struct dnnl_gemm_kernel;
129 
130 // dnnl_gemm_kernel for floats defined as a thin layer on top of mkldnn_sgemm.
131 template <typename IndexType, typename OutputMapper, bool ConjugateLhs,
132           bool ConjugateRhs>
133 struct dnnl_gemm_kernel</*Scalar*/ float, IndexType, OutputMapper, ConjugateLhs,
134                         ConjugateRhs> {
135   static_assert(!ConjugateLhs, "DNNL kernel doesn't support ConjugateLhs");
136   static_assert(!ConjugateRhs, "DNNL kernel doesn't support ConjugateRhs");
137 
138   static constexpr int kComputeStrideFromBlockDimensions = -1;
139 
140   using LhsScalar = float;
141   using RhsScalar = float;
142   using ResScalar = float;
143 
144   EIGEN_DONT_INLINE
145   void operator()(const OutputMapper& output, const LhsScalar* blockA,
146                   const RhsScalar* blockB, const IndexType rows,
147                   const IndexType depth, const IndexType cols, float alpha,
148                   float beta, int ldA = kComputeStrideFromBlockDimensions,
149                   int ldB = kComputeStrideFromBlockDimensions,
150                   char transposeA = 'N', char transposeB = 'N') {
151     static const int max_index = (std::numeric_limits<int>::max)();
152 
153     eigen_assert(max_index >= rows);
154     eigen_assert(max_index >= cols);
155     eigen_assert(max_index >= depth);
156     eigen_assert(max_index >= output.stride());
157 
158     const int m = static_cast<int>(rows);
159     const int n = static_cast<int>(cols);
160     const int k = static_cast<int>(depth);
161 
162     ldA = ldA == kComputeStrideFromBlockDimensions ? m : ldA;
163     ldB = ldB == kComputeStrideFromBlockDimensions ? k : ldB;
164     const int ldC = static_cast<int>(output.stride());
165 
166     // DNNL takes row-major matrices. Our packed column-major matrices can be
167     // viewed as a transposed row-major matrix, i.e.,
168     //   C_colmajor = C_rowmajor^T = (A_rowmajor * B_rowmajor)^T
169     //                             = B_rowmajor^T * A_rowmajor^T
170     //                             = B_colmajor * A_colmajor
171     // So we can just swap the input matrices A and B for DNNL.
172     // TODO(penporn): Switch to row-major packing instead.
173     dnnl_status_t st =
174         dnnl_sgemm(transposeB, transposeA, n, m, k, alpha, blockB, ldB, blockA,
175                    ldA, beta, const_cast<ResScalar*>(output.data()), ldC);
176     eigen_assert(st == 0);
177 
178 #if DYNAMIC_ANNOTATIONS_ENABLED == 1 || defined(MEMORY_SANITIZER)
179     for (IndexType col = 0; col < cols; ++col) {
180       ResScalar* row_base = &output(0, col);
181       EIGEN_UNUSED_VARIABLE(row_base);  // Suppress unused variable error.
182       TF_ANNOTATE_MEMORY_IS_INITIALIZED(row_base, sizeof(ResScalar) * rows);
183     }
184 #endif
185 
186     // eigen_assert is a no-op in optimized mode so we add these to avoid
187     // compiler's unused-variable errors.
188     EIGEN_UNUSED_VARIABLE(max_index);
189     EIGEN_UNUSED_VARIABLE(st);
190   }
191 };
192 
193 template <typename IndexType, typename OutputMapper, bool ConjugateLhs = false,
194           bool ConjugateRhs = false>
195 struct mkldnn_gemm_s8u8s32_kernel {
196   static_assert(!ConjugateLhs, "DNNL kernel doesn't support ConjugateLhs");
197   static_assert(!ConjugateRhs, "DNNL kernel doesn't support ConjugateRhs");
198 
199   static constexpr int kComputeStrideFromBlockDimensions = -1;
200 
201   using LhsScalar = Eigen::QInt8;
202   using RhsScalar = Eigen::QUInt8;
203   using ResScalar = Eigen::QInt32;
204 
205   EIGEN_DONT_INLINE
206   void operator()(const OutputMapper& output, const LhsScalar* blockA,
207                   const RhsScalar* blockB, const IndexType rows,
208                   const IndexType depth, const IndexType cols, float alpha,
209                   float beta, int ldA = kComputeStrideFromBlockDimensions,
210                   int ldB = kComputeStrideFromBlockDimensions,
211                   char transposeA = 'N', char transposeB = 'N') {
212     static const int max_index = (std::numeric_limits<int>::max)();
213 
214     eigen_assert(max_index >= rows);
215     eigen_assert(max_index >= cols);
216     eigen_assert(max_index >= depth);
217     eigen_assert(max_index >= output.stride());
218 
219     const int m = static_cast<int>(rows);
220     const int n = static_cast<int>(cols);
221     const int k = static_cast<int>(depth);
222 
223     ldA = ldA == kComputeStrideFromBlockDimensions ? m : ldA;
224     ldB = ldB == kComputeStrideFromBlockDimensions ? k : ldB;
225     const int ldC = static_cast<int>(output.stride());
226 
227     // Currently we support only symmetric quantization with zero point at 0.
228     const int8_t ao = 0;
229     const int8_t bo = 0;
230 
231     // Don't add any offset to the result C.
232     const char offsetc = 'F';
233     const int32_t co = 0;
234 
235     const auto* A = reinterpret_cast<const int8_t*>(blockA);
236     const auto* B = reinterpret_cast<const uint8_t*>(blockB);
237     auto* C = reinterpret_cast<int32_t*>(const_cast<ResScalar*>(output.data()));
238 
239     // DNNL takes row-major matrices. Our packed column-major matrices can be
240     // viewed as a transposed row-major matrix, i.e., C_colmajor = C_rowmajor^T.
241     // C_colmajor = C_rowmajor^T = (A_rowmajor * B_rowmajor)^T
242     //                           = B_rowmajor^T * A_rowmajor^T
243     //                           = B_colmajor * A_colmajor
244     // So we can just swap the input matrices A and B for DNNL.
245     // TODO(penporn): Switch to row-major packing instead.
246     dnnl_status_t st = dnnl_gemm_u8s8s32(transposeB, transposeA, offsetc,  //
247                                          n, m, k,                          //
248                                          alpha,                            //
249                                          B, ldB, bo,                       //
250                                          A, ldA, ao,                       //
251                                          beta,                             //
252                                          C, ldC, &co);
253     eigen_assert(st == 0);
254 
255 #if DYNAMIC_ANNOTATIONS_ENABLED == 1 || defined(MEMORY_SANITIZER)
256     for (IndexType col = 0; col < cols; ++col) {
257       ResScalar* row_base = &output(0, col);
258       EIGEN_UNUSED_VARIABLE(row_base);  // Suppress unused variable error.
259       TF_ANNOTATE_MEMORY_IS_INITIALIZED(row_base, sizeof(ResScalar) * rows);
260     }
261 #endif
262 
263     // eigen_assert is a no-op in optimized mode so we add these to avoid
264     // compiler's unused-variable errors.
265     EIGEN_UNUSED_VARIABLE(max_index);
266     EIGEN_UNUSED_VARIABLE(st);
267   }
268 };
269 
270 // For mkldnn_sgemm having the right dimensions (especially for small matrices)
271 // is more important than fitting all the working set in L1/L2 caches.
272 // TODO(ezhulenev): Do better heuristics.
273 template <typename StorageIndex, int sharding_type>
274 class TensorContractionBlocking<float, float, float, StorageIndex,
275                                 sharding_type> {
276   // For now mkldnn has only mkldnn_sgemm (gemm for floats).
277   using Scalar = float;
278 
279   // Adjust the block sizes to work well with mkldnn kernels.
280 
281   // Multiply default choice of block size along M and N dimensions.
282   // TODO(ezhulenev): Explore if this can work in general (kScaleM=2.0 worked
283   // well in some of models).
284   static constexpr float kScaleM = 1.5;
285   static constexpr float kScaleN = 1.0;
286 
287   // Mkldnn Avx/Avx2/Avx512 unroll factors are: 8/16/48.
288   static constexpr StorageIndex kUnrollM = 48;
289 
290   // Mkldnn Avx/Avx2/Avx512 unroll factors are: 6/6/8.
291   static constexpr StorageIndex kUnrollN = 24;
292 
293  public:
294   TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n,
295                             StorageIndex num_threads = 1)
296       : kc_(k), mc_(m), nc_(n) {
297     // 1. Compute block sizes using default Eigen heuristics.
298     if (sharding_type == ShardByCol) {
299       computeProductBlockingSizes<Scalar, Scalar, 1>(kc_, mc_, nc_,
300                                                      num_threads);
301     } else {
302       computeProductBlockingSizes<Scalar, Scalar, 1>(kc_, nc_, mc_,
303                                                      num_threads);
304     }
305 
306     // If dimensions do not pass basic sanity checks return immediately.
307     if (kc_ <= 0 || mc_ <= 0 || nc_ <= 0) return;
308 
309     // If we are using default Eigen gebp kernel there is no need to adjust the
310     // block sizes for DNNL.
311     if (!UseCustomContractionKernels()) return;
312 
313     // 2. And refine them to work well with mkldnn sgemm.
314     mc_ = (std::min)(
315         m, Eigen::divup(static_cast<StorageIndex>(mc_ * kScaleM), kUnrollM) *
316                kUnrollM);
317     nc_ = (std::min)(
318         n, Eigen::divup(static_cast<StorageIndex>(nc_ * kScaleN), kUnrollN) *
319                kUnrollN);
320 
321     // We split Kth dimensions in roughly equal slices.
322     StorageIndex target_k_slices =
323         (std::max)(StorageIndex(1), Eigen::divup(k, kc_));
324     StorageIndex packet_size = internal::packet_traits<Scalar>::size;
325     if (packet_size < 8) packet_size = 8;
326     StorageIndex target_bk =
327         Eigen::divup(k / target_k_slices, packet_size) * packet_size;
328     kc_ = (std::min)(k, target_bk);
329   }
330 
331   EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }
332   EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; }
333   EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; }
334 
335  private:
336   StorageIndex kc_;
337   StorageIndex mc_;
338   StorageIndex nc_;
339 };
340 
341 template <typename StorageIndex, int sharding_type>
342 class TensorContractionBlocking<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
343                                 StorageIndex, sharding_type> {
344   // TODO(ezhulenev): Define proper gebp_traits in Eigen for quantized types?
345 
346   // Default Eigen block heuristics for `QInt8xQUInt8 -> QInt32` are wrong.
347   // Mostly because gebp_traits are not correctly defined. But we know that we
348   // are going to use s8u8s32_gemm from DNNL, so we use float heuristics, and
349   // adjust them to work well with DNNL.
350   using LhsScalar = Eigen::QInt8;
351   using RhsScalar = Eigen::QUInt8;
352   using ResScalar = Eigen::QInt32;
353 
354   // Multiply default choice of block size along M, N and K dimensions.
355   static constexpr float kScaleM = 1.5;
356   static constexpr float kScaleN = 1.5;
357   static constexpr float kScaleK = 1.5;
358 
359  public:
360   TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n,
361                             StorageIndex num_threads = 1)
362       : kc_(k), mc_(m), nc_(n) {
363     // Each dimension is a multiple of 32 (fits into _m256i).
364     mc_ = (std::min)(m, static_cast<StorageIndex>(192));
365     nc_ = (std::min)(n, static_cast<StorageIndex>(288));
366     kc_ = (std::min)(k, static_cast<StorageIndex>(320));
367   }
368 
369   EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }
370   EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; }
371   EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; }
372 
373  private:
374   StorageIndex kc_;
375   StorageIndex mc_;
376   StorageIndex nc_;
377 };
378 
// Describes one Lhs/Rhs block handed to the gemm kernel.
//
// If the Lhs or Rhs Tensor expressions are already evaluated and have access
// to raw data, we can skip the packing step, set up a pointer and a stride to
// the underlying memory buffer, and pass them directly to Gemm. Otherwise the
// data is packed into `packed_data`.
//
// Every member has a default initializer, so a default-constructed block is in
// a well-defined "not direct access, nothing attached" state instead of holding
// indeterminate values. It remains an aggregate (C++14).
template <typename Scalar, typename StorageIndex>
struct ColMajorBlock {
  bool is_direct_access = false;

  // Valid iff `is_direct_access == false`
  Scalar* packed_data = nullptr;

  // Valid iff `is_direct_access == true`
  Scalar* raw_data = nullptr;
  StorageIndex stride = 0;
  // Transpose flag for `raw_data` ('N' = not transposed); presumably forwarded
  // to the gemm call as its transpose argument.
  char transpose = 'N';
};
394 
395 template <typename DataMapper>
396 struct DirectColMajorAccess {
397   enum { value = false };
398 
399   template <typename Scalar, typename StorageIndex>
400   static bool block(const typename DataMapper::SubMapper& data_mapper,
401                     const StorageIndex rows, const StorageIndex cols,
402                     const StorageIndex num_kernels,
403                     ColMajorBlock<Scalar, StorageIndex>* block) {
404     eigen_assert(false && "Not implemented");
405     return false;
406   }
407 };
408 
// If we have access to the raw memory of the contraction input, we can safely
// skip packing if:
//   (1) Packing is a no-op.
//   (2) The packed block will be used just once.
// In addition, a block used by exactly two kernels is accessed directly when
// its addressable memory (stride * cols * sizeof(Scalar)) is below 256 KiB.
//
// If a packed block is used many times, it's more efficient to pack it into a
// contiguous block of memory to reduce pressure on the TLB.
//
// REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_EXPR) defines a partial
// specialization of DirectColMajorAccess for a TensorContractionInputMapper
// over `TensorEvaluator<TENSOR_EXPR, Device>` with a contiguous, non-reordered
// inner dimension. Its `block()` either points `block` at the raw buffer (with
// stride and transpose 'N') and returns true, or returns false without
// touching `block`.
//
// TODO(ezhulenev): Add support for more tensor expressions that matter.
#define REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_EXPR)                          \
  template <typename Scalar, typename StorageIndex, int Side, typename Device, \
            typename nocontract_t, typename contract_t, int packet_size,       \
            int Alignment>                                                     \
  struct DirectColMajorAccess<TensorContractionInputMapper<                    \
      Scalar, StorageIndex, Side, TensorEvaluator<TENSOR_EXPR, Device>,        \
      nocontract_t, contract_t, packet_size, /*inner_dim_contiguous=*/true,    \
      /*inner_dim_reordered=*/false, Alignment>> {                             \
    enum { value = true };                                                     \
                                                                               \
    using DataMapper = TensorContractionInputMapper<                           \
        Scalar, StorageIndex, Side, TensorEvaluator<TENSOR_EXPR, Device>,      \
        nocontract_t, contract_t, packet_size, /*inner_dim_contiguous=*/true,  \
        /*inner_dim_reordered=*/false, Alignment>;                             \
                                                                               \
    static bool block(const typename DataMapper::SubMapper& data_mapper,       \
                      const StorageIndex rows, const StorageIndex cols,        \
                      const StorageIndex num_kernels,                          \
                      ColMajorBlock<Scalar, StorageIndex>* block) {            \
      static_assert(DataMapper::DirectOffsets == true,                         \
                    "DataMapper must support direct offsets");                 \
                                                                               \
      const StorageIndex vert_offset = data_mapper.vert_offset();              \
      const StorageIndex horiz_offset = data_mapper.horiz_offset();            \
      /* Leading dimension of the raw column-major buffer on this side. */     \
      const StorageIndex stride =                                              \
          Side == Lhs ? data_mapper.base_mapper().stride()                     \
                      : data_mapper.base_mapper().nocontract_strides()[0];     \
      const Scalar* data = data_mapper.base_mapper().tensor().data();          \
      /* Only the Rhs pointer is advanced to the sub-block here. For the    */ \
      /* Lhs the offsets are presumably already baked into the base         */ \
      /* mapper's buffer (DirectOffsets). NOTE(review): confirm in Eigen's  */ \
      /* TensorContractionSubMapper.                                        */ \
      data = Side == Lhs ? data : data + vert_offset + horiz_offset * stride;  \
                                                                               \
      /* Packing is a no-op when the columns are already densely laid out. */  \
      const bool is_no_op_packing = stride == rows;                            \
      const StorageIndex addressable_mem = (stride * cols * sizeof(Scalar));   \
      const bool use_direct_access =                                           \
          is_no_op_packing || num_kernels == 1 /* used once */ ||              \
          ((num_kernels == 2) &&                                               \
           (addressable_mem < (256 << 10) /* 256 kb */));                      \
                                                                               \
      /* On success point the block at the raw buffer; on failure `block`  */  \
      /* is left untouched and the data is expected to be packed instead.  */  \
      if (use_direct_access) {                                                 \
        block->is_direct_access = true;                                        \
        block->raw_data = const_cast<Scalar*>(data);                           \
        block->stride = stride;                                                \
        block->transpose = 'N';                                                \
        return true;                                                           \
      }                                                                        \
      return false;                                                            \
    }                                                                          \
  }
465 
466 #define SIMPLE_TENSOR const Tensor<Scalar, 2, Eigen::ColMajor, StorageIndex>
467 
468 #define TENSOR_MAP_ROWMAJOR                                               \
469   const TensorMap<Tensor<const Scalar, 2, Eigen::RowMajor, StorageIndex>, \
470                   Eigen::Aligned>
471 
472 #define TENSOR_MAP_COLMAJOR                                               \
473   const TensorMap<Tensor<const Scalar, 2, Eigen::ColMajor, StorageIndex>, \
474                   Eigen::Aligned>
475 
476 #define TENSOR_MAP_CONST_ROWMAJOR                                   \
477   const TensorMap<Tensor<Scalar, 2, Eigen::RowMajor, StorageIndex>, \
478                   Eigen::Aligned>
479 
480 #define TENSOR_MAP_CONST_COLMAJOR                                   \
481   const TensorMap<Tensor<Scalar, 2, Eigen::ColMajor, StorageIndex>, \
482                   Eigen::Aligned>
483 
484 // This is reshaped convolution filter from `eigen_spatial_convolutions.h`.
485 #define TENSOR_RESHAPE                                                        \
486   const TensorReshapingOp<                                                    \
487       const Eigen::DSizes<StorageIndex, 2>,                                   \
488       const TensorMap<Tensor<const Scalar, 4, Eigen::RowMajor, StorageIndex>, \
489                       Eigen::Aligned>>
490 
491 REGISTER_DIRECT_COL_MAJOR_ACCESS(SIMPLE_TENSOR);
492 REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_ROWMAJOR);
493 REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_COLMAJOR);
494 REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_CONST_ROWMAJOR);
495 REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_CONST_COLMAJOR);
496 REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_RESHAPE);
497 
498 #undef SIMPLE_TENSOR
499 #undef TENSOR_MAP_ROWMAJOR
500 #undef TENSOR_MAP_COLMAJOR
501 #undef TENSOR_MAP_CONST_ROWMAJOR
502 #undef TENSOR_MAP_CONST_COLMAJOR
503 #undef TENSOR_RESHAPE
504 #undef REGISTER_DIRECT_COL_MAJOR_ACCESS
505 
506 template <typename ResScalar, typename LhsScalar, typename RhsScalar,
507           typename StorageIndex, typename OutputMapper>
508 struct GemmKernelProvider {
509   enum { Defined = 0 };
510   using GemmKernel = void;
511 };
512 
513 template <typename StorageIndex, typename OutputMapper>
514 struct GemmKernelProvider<float, float, float, StorageIndex, OutputMapper> {
515   enum { Defined = 1 };
516   using GemmKernel = dnnl_gemm_kernel<float, StorageIndex, OutputMapper>;
517 };
518 
519 template <typename StorageIndex, typename OutputMapper>
520 struct GemmKernelProvider<Eigen::QInt32, Eigen::QInt8, Eigen::QUInt8,
521                           StorageIndex, OutputMapper> {
522   enum { Defined = 1 };
523   using GemmKernel = mkldnn_gemm_s8u8s32_kernel<StorageIndex, OutputMapper>;
524 };
525 
526 // NOTE: 'std::enable_if' doesn't work for template specializations. See
527 // "default template argument in a class template partial specialization".
528 
529 // Tensor contraction kernel that can fallback on Eigen gebp_kernel at runtime.
530 #define REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK(                      \
531     RES_SCALAR, LHS_SCALAR, RHS_SCALAR)                                        \
532                                                                                \
533   template <typename StorageIndex, typename OutputMapper, typename LhsMapper,  \
534             typename RhsMapper>                                                \
535   struct TensorContractionKernel<RES_SCALAR, LHS_SCALAR, RHS_SCALAR,           \
536                                  StorageIndex, OutputMapper, LhsMapper,        \
537                                  RhsMapper> {                                  \
538     TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n,    \
539                             StorageIndex bm, StorageIndex bk, StorageIndex bn) \
540         : m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {}                          \
541                                                                                \
542     enum { HasBeta = true };                                                   \
543                                                                                \
544     using ResScalar = RES_SCALAR;                                              \
545     using LhsScalar = LHS_SCALAR;                                              \
546     using RhsScalar = RHS_SCALAR;                                              \
547                                                                                \
548     using Traits = typename internal::gebp_traits<LhsScalar, RhsScalar>;       \
549                                                                                \
550     using LhsBlock = ColMajorBlock<LhsScalar, StorageIndex>;                   \
551     using RhsBlock = ColMajorBlock<RhsScalar, StorageIndex>;                   \
552                                                                                \
553     using DirectLhsAccess = DirectColMajorAccess<LhsMapper>;                   \
554     using DirectRhsAccess = DirectColMajorAccess<RhsMapper>;                   \
555                                                                                \
556     /* Packed Lhs/Rhs block memory allocator.*/                                \
557     typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar>           \
558         BlockMemAllocator;                                                     \
559     typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle;         \
560                                                                                \
561     using LhsPacker =                                                          \
562         gemm_pack_colmajor_block<LhsScalar, StorageIndex,                      \
563                                  typename LhsMapper::SubMapper, ColMajor>;     \
564     using RhsPacker =                                                          \
565         gemm_pack_colmajor_block<RhsScalar, StorageIndex,                      \
566                                  typename RhsMapper::SubMapper, ColMajor>;     \
567                                                                                \
568     using GemmKernelProviderType =                                             \
569         GemmKernelProvider<ResScalar, LhsScalar, RhsScalar, StorageIndex,      \
570                            OutputMapper>;                                      \
571     static_assert(                                                             \
572         GemmKernelProviderType::Defined,                                       \
573         "Custom GEMM kernel is not registered for given scalar types");        \
574     using GemmKernel = typename GemmKernelProviderType::GemmKernel;            \
575                                                                                \
576     /* Fallback on default Eigen pack and GEBP kernel if custom contraction */ \
577     /* kernels disabled at runtime.                                         */ \
578     using EigenLhsPacker =                                                     \
579         gemm_pack_lhs<LhsScalar, StorageIndex, typename LhsMapper::SubMapper,  \
580                       Traits::mr, Traits::LhsProgress,                         \
581                       typename Traits::LhsPacket4Packing, ColMajor>;           \
582     using EigenRhsPacker =                                                     \
583         gemm_pack_rhs<RhsScalar, StorageIndex, typename RhsMapper::SubMapper,  \
584                       Traits::nr, ColMajor>;                                   \
585     using GebpKernel =                                                         \
586         gebp_kernel<LhsScalar, RhsScalar, StorageIndex, OutputMapper,          \
587                     Traits::mr, Traits::nr, /*ConjugateLhs*/ false,            \
588                     /*ConjugateRhs*/ false>;                                   \
589                                                                                \
590     template <typename Device>                                                 \
591     EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block,  \
592                                               RhsBlock* rhs_block) {           \
593       return BlockMemAllocator::allocate(                                      \
594           d, bm, bk, bn, &lhs_block->packed_data, &rhs_block->packed_data);    \
595     }                                                                          \
596                                                                                \
597     template <typename Device>                                                 \
598     EIGEN_DEVICE_FUNC BlockMemHandle                                           \
599     allocateSlices(Device& d, const int num_lhs, const int num_rhs,            \
600                    const int num_slices, std::vector<LhsBlock>* lhs_blocks,    \
601                    std::vector<RhsBlock>* rhs_blocks) {                        \
602       eigen_assert(num_slices > 0);                                            \
603       std::vector<std::vector<LhsScalar*>> lhs_mem(num_slices);                \
604       std::vector<std::vector<RhsScalar*>> rhs_mem(num_slices);                \
605                                                                                \
606       BlockMemHandle block_mem = BlockMemAllocator::allocateSlices(            \
607           d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_mem.data(),         \
608           rhs_mem.data());                                                     \
609                                                                                \
610       for (Index x = 0; x < num_slices; x++) {                                 \
611         if (num_lhs > 0) lhs_blocks[x].resize(num_lhs);                        \
612         for (Index m = 0; m < num_lhs; m++) {                                  \
613           lhs_blocks[x][m].packed_data = lhs_mem[x][m];                        \
614         }                                                                      \
615         if (num_rhs > 0) rhs_blocks[x].resize(num_rhs);                        \
616         for (Index n = 0; n < num_rhs; n++) {                                  \
617           rhs_blocks[x][n].packed_data = rhs_mem[x][n];                        \
618         }                                                                      \
619       }                                                                        \
620                                                                                \
621       return block_mem;                                                        \
622     }                                                                          \
623                                                                                \
624     template <typename Device>                                                 \
625     EIGEN_DEVICE_FUNC static void deallocate(Device& d,                        \
626                                              BlockMemHandle handle) {          \
627       BlockMemAllocator::deallocate(d, handle);                                \
628     }                                                                          \
629                                                                                \
630     EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(                          \
631         LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper,  \
632         const StorageIndex depth, const StorageIndex rows) {                   \
633       if (UseCustomContractionKernels()) {                                     \
634         const bool is_direct_access =                                          \
635             DirectLhsAccess::value &&                                          \
636             DirectLhsAccess::block(data_mapper, rows, depth,                   \
637                                    bn > 0 ? divup(n, bn) : 0, lhsBlock);       \
638                                                                                \
639         if (!is_direct_access) {                                               \
640           lhsBlock->is_direct_access = false;                                  \
641           LhsPacker()(lhsBlock->packed_data, data_mapper, rows, depth);        \
642         }                                                                      \
643       } else {                                                                 \
644         lhsBlock->is_direct_access = false;                                    \
645         EigenLhsPacker()(lhsBlock->packed_data, data_mapper, depth, rows,      \
646                          /*stride*/ 0, /*offset*/ 0);                          \
647       }                                                                        \
648     }                                                                          \
649                                                                                \
650     EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(                          \
651         RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper,  \
652         const StorageIndex depth, const StorageIndex cols) {                   \
653       if (UseCustomContractionKernels()) {                                     \
654         const bool is_direct_access =                                          \
655             DirectRhsAccess::value &&                                          \
656             DirectRhsAccess::block(data_mapper, depth, cols,                   \
657                                    bm > 0 ? divup(m, bm) : 0, rhsBlock);       \
658                                                                                \
659         if (!is_direct_access) {                                               \
660           rhsBlock->is_direct_access = false;                                  \
661           RhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols);        \
662         }                                                                      \
663       } else {                                                                 \
664         rhsBlock->is_direct_access = false;                                    \
665         EigenRhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols);     \
666       }                                                                        \
667     }                                                                          \
668                                                                                \
669     EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(                           \
670         const OutputMapper& output_mapper, const LhsBlock& lhsBlock,           \
671         const RhsBlock& rhsBlock, const StorageIndex rows,                     \
672         const StorageIndex depth, const StorageIndex cols, const float alpha,  \
673         const float beta) {                                                    \
674       if (UseCustomContractionKernels()) {                                     \
675         if ((DirectLhsAccess::value && lhsBlock.is_direct_access) &&           \
676             (DirectRhsAccess::value && rhsBlock.is_direct_access)) {           \
677           GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.raw_data,    \
678                        rows, depth, cols, alpha, beta,                         \
679                        /*ldA=*/lhsBlock.stride, /*ldB=*/rhsBlock.stride,       \
680                        /*transposeA=*/lhsBlock.transpose,                      \
681                        /*transposeB=*/rhsBlock.transpose);                     \
682                                                                                \
683         } else if (DirectLhsAccess::value && lhsBlock.is_direct_access) {      \
684           GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.packed_data, \
685                        rows, depth, cols, alpha, beta,                         \
686                        /*ldA=*/lhsBlock.stride,                                \
687                        /*ldB=*/GemmKernel::kComputeStrideFromBlockDimensions,  \
688                        /*transposeA=*/lhsBlock.transpose, /*transposeB=*/'N'); \
689                                                                                \
690         } else if (DirectRhsAccess::value && rhsBlock.is_direct_access) {      \
691           GemmKernel()(output_mapper, lhsBlock.packed_data, rhsBlock.raw_data, \
692                        rows, depth, cols, alpha, beta,                         \
693                        /*ldA=*/GemmKernel::kComputeStrideFromBlockDimensions,  \
694                        /*ldB=*/rhsBlock.stride, /*transposeA=*/'N',            \
695                        /*transposeB=*/rhsBlock.transpose);                     \
696                                                                                \
697         } else {                                                               \
698           GemmKernel()(output_mapper, lhsBlock.packed_data,                    \
699                        rhsBlock.packed_data, rows, depth, cols, alpha, beta);  \
700         }                                                                      \
701       } else {                                                                 \
702         /* Gebp kernel does not support beta, so we have to clear memory in */ \
703         /* the output mapper manually.                                      */ \
704         /* WARNING(ezhulenev): This is optimized into a memset in a loop,   */ \
705         /* could be much slower for small matrices. Currently this code     */ \
706         /* path used only for testing, and performance does not matter.     */ \
707         if (beta == 0.0) {                                                     \
708           for (StorageIndex col = 0; col < cols; ++col) {                      \
709             ResScalar* output_base = &output_mapper(0, col);                   \
710             typedef Array<ResScalar, Dynamic, 1> OutputRow;                    \
711             typedef Map<OutputRow, 0, InnerStride<1>> OutputRowMap;            \
712             OutputRowMap(output_base, rows).setZero();                         \
713           }                                                                    \
714         }                                                                      \
715                                                                                \
716         GebpKernel()(                                                          \
717             output_mapper, lhsBlock.packed_data, rhsBlock.packed_data, rows,   \
718             depth, cols, alpha,                                                \
719             /*strideA*/ GemmKernel::kComputeStrideFromBlockDimensions,         \
720             /*strideB*/ GemmKernel::kComputeStrideFromBlockDimensions,         \
721             /*offsetA*/ 0, /*offsetB*/ 0);                                     \
722       }                                                                        \
723     }                                                                          \
724                                                                                \
725    private:                                                                    \
726     /* These are dimensions of the original Tensors, and selected block     */ \
727     /* sizes. The actual block sizes passed to all function above might be  */ \
728     /* smaller because of the partial blocks at the end.                    */ \
729     const StorageIndex m;                                                      \
730     const StorageIndex k;                                                      \
731     const StorageIndex n;                                                      \
732     const StorageIndex bm;                                                     \
733     const StorageIndex bk;                                                     \
734     const StorageIndex bn;                                                     \
735   }
736 
// Tensor contraction kernel that does not fall back on Eigen. Currently not
// all data types are supported by Eigen data packing and the default
// gebp_kernel, so unlike REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK this
// variant never consults UseCustomContractionKernels(): packing always uses
// gemm_pack_colmajor_block (or direct column-major access) and multiplication
// always uses the GemmKernel registered for the scalar types.
//
// Expands to a partial specialization of TensorContractionKernel for
// (RES_SCALAR, LHS_SCALAR, RHS_SCALAR); instantiating it fails with a
// static_assert if GemmKernelProvider has no kernel for those scalar types.
#define REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK(RES_SCALAR, LHS_SCALAR, \
                                                       RHS_SCALAR)             \
                                                                               \
  template <typename StorageIndex, typename OutputMapper, typename LhsMapper,  \
            typename RhsMapper>                                                \
  struct TensorContractionKernel<RES_SCALAR, LHS_SCALAR, RHS_SCALAR,           \
                                 StorageIndex, OutputMapper, LhsMapper,        \
                                 RhsMapper> {                                  \
    TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n,    \
                            StorageIndex bm, StorageIndex bk, StorageIndex bn) \
        : m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {}                          \
                                                                               \
    enum { HasBeta = true };                                                   \
                                                                               \
    using ResScalar = RES_SCALAR;                                              \
    using LhsScalar = LHS_SCALAR;                                              \
    using RhsScalar = RHS_SCALAR;                                              \
                                                                               \
    using Traits = typename internal::gebp_traits<LhsScalar, RhsScalar>;       \
                                                                               \
    using LhsBlock = ColMajorBlock<LhsScalar, StorageIndex>;                   \
    using RhsBlock = ColMajorBlock<RhsScalar, StorageIndex>;                   \
                                                                               \
    using DirectLhsAccess = DirectColMajorAccess<LhsMapper>;                   \
    using DirectRhsAccess = DirectColMajorAccess<RhsMapper>;                   \
                                                                               \
    /* Packed Lhs/Rhs block memory allocator.*/                                \
    typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar>           \
        BlockMemAllocator;                                                     \
    typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle;         \
                                                                               \
    using LhsPacker =                                                          \
        gemm_pack_colmajor_block<LhsScalar, StorageIndex,                      \
                                 typename LhsMapper::SubMapper, ColMajor>;     \
    using RhsPacker =                                                          \
        gemm_pack_colmajor_block<RhsScalar, StorageIndex,                      \
                                 typename RhsMapper::SubMapper, ColMajor>;     \
                                                                               \
    using GemmKernelProviderType =                                             \
        GemmKernelProvider<ResScalar, LhsScalar, RhsScalar, StorageIndex,      \
                           OutputMapper>;                                      \
    static_assert(                                                             \
        GemmKernelProviderType::Defined,                                       \
        "Custom GEMM kernel is not registered for given scalar types");        \
    using GemmKernel = typename GemmKernelProviderType::GemmKernel;            \
                                                                               \
    /* Allocates packed memory for one Lhs and one Rhs block sized from     */ \
    /* (bm, bk, bn); returns the handle to pass to deallocate().            */ \
    template <typename Device>                                                 \
    EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block,  \
                                              RhsBlock* rhs_block) {           \
      return BlockMemAllocator::allocate(                                      \
          d, bm, bk, bn, &lhs_block->packed_data, &rhs_block->packed_data);    \
    }                                                                          \
                                                                               \
    /* Allocates packed memory for `num_slices` slices, each with           */ \
    /* `num_lhs` Lhs and `num_rhs` Rhs blocks, and points each block's      */ \
    /* packed_data at its memory. `lhs_blocks` and `rhs_blocks` are         */ \
    /* indexed as arrays of `num_slices` vectors.                           */ \
    template <typename Device>                                                 \
    EIGEN_DEVICE_FUNC BlockMemHandle                                           \
    allocateSlices(Device& d, const int num_lhs, const int num_rhs,            \
                   const int num_slices, std::vector<LhsBlock>* lhs_blocks,    \
                   std::vector<RhsBlock>* rhs_blocks) {                        \
      eigen_assert(num_slices > 0);                                            \
      std::vector<std::vector<LhsScalar*>> lhs_mem(num_slices);                \
      std::vector<std::vector<RhsScalar*>> rhs_mem(num_slices);                \
                                                                               \
      BlockMemHandle block_mem = BlockMemAllocator::allocateSlices(            \
          d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_mem.data(),         \
          rhs_mem.data());                                                     \
                                                                               \
      /* Loop indices avoid the names `m`/`n`, which are members.         */ \
      for (Index slice = 0; slice < num_slices; ++slice) {                     \
        if (num_lhs > 0) lhs_blocks[slice].resize(num_lhs);                    \
        for (Index lhs_idx = 0; lhs_idx < num_lhs; ++lhs_idx) {                \
          lhs_blocks[slice][lhs_idx].packed_data = lhs_mem[slice][lhs_idx];    \
        }                                                                      \
        if (num_rhs > 0) rhs_blocks[slice].resize(num_rhs);                    \
        for (Index rhs_idx = 0; rhs_idx < num_rhs; ++rhs_idx) {                \
          rhs_blocks[slice][rhs_idx].packed_data = rhs_mem[slice][rhs_idx];    \
        }                                                                      \
      }                                                                        \
                                                                               \
      return block_mem;                                                        \
    }                                                                          \
                                                                               \
    /* Releases memory obtained from allocate() or allocateSlices().        */ \
    template <typename Device>                                                 \
    EIGEN_DEVICE_FUNC static void deallocate(Device& d,                        \
                                             BlockMemHandle handle) {          \
      BlockMemAllocator::deallocate(d, handle);                                \
    }                                                                          \
                                                                               \
    /* Packs an Lhs block with LhsPacker, unless DirectLhsAccess can        */ \
    /* use the mapper's memory directly (DirectLhsAccess::block()           */ \
    /* returns true). The last argument is the number of bn-wide            */ \
    /* column blocks, presumably the block's reuse count (confirm).         */ \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(                          \
        LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper,  \
        const StorageIndex depth, const StorageIndex rows) {                   \
      const bool is_direct_access =                                            \
          DirectLhsAccess::value &&                                            \
          DirectLhsAccess::block(data_mapper, rows, depth,                     \
                                 bn > 0 ? divup(n, bn) : 0, lhsBlock);         \
                                                                               \
      if (!is_direct_access) {                                                 \
        lhsBlock->is_direct_access = false;                                    \
        LhsPacker()(lhsBlock->packed_data, data_mapper, rows, depth);          \
      }                                                                        \
    }                                                                          \
                                                                               \
    /* Packs an Rhs block with RhsPacker, unless DirectRhsAccess can        */ \
    /* use the mapper's memory directly (DirectRhsAccess::block()           */ \
    /* returns true). The last argument is the number of bm-high row        */ \
    /* blocks, presumably the block's reuse count (confirm).                */ \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(                          \
        RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper,  \
        const StorageIndex depth, const StorageIndex cols) {                   \
      const bool is_direct_access =                                            \
          DirectRhsAccess::value &&                                            \
          DirectRhsAccess::block(data_mapper, depth, cols,                     \
                                 bm > 0 ? divup(m, bm) : 0, rhsBlock);         \
                                                                               \
      if (!is_direct_access) {                                                 \
        rhsBlock->is_direct_access = false;                                    \
        RhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols);          \
      }                                                                        \
    }                                                                          \
                                                                               \
    /* Runs GemmKernel on the two blocks, forwarding alpha and beta.        */ \
    /* Directly accessed operands are passed as raw_data with their         */ \
    /* own stride and transpose flag; packed operands use the packed        */ \
    /* layout ('N', stride computed from the block dimensions).             */ \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(                           \
        const OutputMapper& output_mapper, const LhsBlock& lhsBlock,           \
        const RhsBlock& rhsBlock, const StorageIndex rows,                     \
        const StorageIndex depth, const StorageIndex cols, const float alpha,  \
        const float beta) {                                                    \
      if ((DirectLhsAccess::value && lhsBlock.is_direct_access) &&             \
          (DirectRhsAccess::value && rhsBlock.is_direct_access)) {             \
        GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.raw_data,      \
                     rows, depth, cols, alpha, beta, /*ldA=*/lhsBlock.stride,  \
                     /*ldB=*/rhsBlock.stride,                                  \
                     /*transposeA=*/lhsBlock.transpose,                        \
                     /*transposeB=*/rhsBlock.transpose);                       \
                                                                               \
      } else if (DirectLhsAccess::value && lhsBlock.is_direct_access) {        \
        GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.packed_data,   \
                     rows, depth, cols, alpha, beta, /*ldA=*/lhsBlock.stride,  \
                     /*ldB=*/GemmKernel::kComputeStrideFromBlockDimensions,    \
                     /*transposeA=*/lhsBlock.transpose, /*transposeB=*/'N');   \
                                                                               \
      } else if (DirectRhsAccess::value && rhsBlock.is_direct_access) {        \
        GemmKernel()(output_mapper, lhsBlock.packed_data, rhsBlock.raw_data,   \
                     rows, depth, cols, alpha, beta,                           \
                     /*ldA=*/GemmKernel::kComputeStrideFromBlockDimensions,    \
                     /*ldB=*/rhsBlock.stride, /*transposeA=*/'N',              \
                     /*transposeB=*/rhsBlock.transpose);                       \
                                                                               \
      } else {                                                                 \
        GemmKernel()(output_mapper, lhsBlock.packed_data,                      \
                     rhsBlock.packed_data, rows, depth, cols, alpha, beta);    \
      }                                                                        \
    }                                                                          \
                                                                               \
   private:                                                                    \
    /* These are dimensions of the original Tensors, and selected block     */ \
    /* sizes. The actual block sizes passed to all function above might be  */ \
    /* smaller because of the partial blocks at the end.                    */ \
    const StorageIndex m;                                                      \
    const StorageIndex k;                                                      \
    const StorageIndex n;                                                      \
    const StorageIndex bm;                                                     \
    const StorageIndex bk;                                                     \
    const StorageIndex bn;                                                     \
  }
896 
897 REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK(float, float, float);
898 REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK(Eigen::QInt32, Eigen::QInt8,
899                                                Eigen::QUInt8);
900 
901 #undef REGISTER_TENSOR_CONTRACTION_KERNEL
902 
903 #endif  // defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
904 
905 }  // namespace internal
906 }  // namespace Eigen
907 
908 #endif  // TENSORFLOW_CORE_KERNELS_EIGEN_CONTRACTION_KERNEL_H_
909