1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
17 #define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
18 #ifdef INTEL_MKL
19 
20 #include <list>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26 
27 #include "mkldnn.hpp"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/graph/mkl_graph_util.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/stringpiece.h"
34 #include "tensorflow/core/lib/gtl/array_slice.h"
35 #include "tensorflow/core/platform/cpu_info.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/util/env_var.h"
39 #include "tensorflow/core/util/mkl_threadpool.h"
40 #include "tensorflow/core/util/padding.h"
41 #include "tensorflow/core/util/tensor_format.h"
42 
43 using mkldnn::engine;
44 using mkldnn::memory;
45 using mkldnn::primitive;
46 using mkldnn::reorder;
47 using mkldnn::stream;
48 using CPUDevice = Eigen::ThreadPoolDevice;
49 using MemoryArgsMap = std::unordered_map<int, memory>;
50 using ReorderPd = mkldnn::reorder::primitive_desc;
51 
52 #ifdef _WIN32
53 typedef unsigned int uint;
54 #endif
55 
56 namespace tensorflow {
57 
58 // This file contains a number of utility classes and functions used by
59 // MKL-enabled kernels.
60 
61 // The MklDnnShape class below encapsulates all the metadata associated with
62 // an MKL tensor. A tensor is an MKL tensor if it was created as the result
63 // of an MKL operation and did not go through a conversion to a standard
64 // TensorFlow tensor.
65 
66 // The dimension order that MKL-DNN uses internally for 2D activations is
67 // [Batch, Channel, Height, Width] and
68 // for 2D filters it is [Out_Channel, In_Channel, Height, Width].
69 typedef enum {
70   Dim_N = 0,
71   Dim_C = 1,
72   Dim_H = 2,
73   Dim_W = 3,
74   Dim_O = 0,
75   Dim_I = 1
76 } MklDnnDims;
77 
78 // The dimension order that MKL-DNN uses internally for 3D activations is
79 // [Batch, Channel, Depth, Height, Width] and
80 // for 3D filters it is [Out_Channel, In_Channel, Depth, Height, Width].
81 typedef enum {
82   Dim3d_N = 0,
83   Dim3d_C = 1,
84   Dim3d_D = 2,
85   Dim3d_H = 3,
86   Dim3d_W = 4,
87   Dim3d_O = 0,
88   Dim3d_I = 1
89 } MklDnnDims3D;
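// For example, with 2D-activation sizes already in MKL-DNN's NCHW order, the
// enums above can be used to index the dims array directly (a minimal sketch;
// `batch`, `channels`, `height` and `width` are placeholder values):
//
//   memory::dims dims = {batch, channels, height, width};  // NCHW order
//   auto h = dims[MklDnnDims::Dim_H];  // height
//   auto c = dims[MklDnnDims::Dim_C];  // channels
//
// The 5D case works the same way with MklDnnDims3D and NCDHW-ordered dims.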
90 
91 // Enum for the order of dimensions of a TF 2D filter with shape [filter_height,
92 // filter_width, in_channels, out_channels]
93 typedef enum {
94   TF_2DFILTER_DIM_H = 0,
95   TF_2DFILTER_DIM_W = 1,
96   TF_2DFILTER_DIM_I = 2,
97   TF_2DFILTER_DIM_O = 3
98 } TFFilterDims2d;
99 
100 // Enum for the order of dimensions of a TF 3D filter with shape [filter_depth,
101 // filter_height, filter_width, in_channels, out_channels]
102 typedef enum {
103   TF_3DFILTER_DIM_P = 0,
104   TF_3DFILTER_DIM_H = 1,
105   TF_3DFILTER_DIM_W = 2,
106   TF_3DFILTER_DIM_I = 3,
107   TF_3DFILTER_DIM_O = 4
108 } TFFilterDims3d;
109 
110 // The dimension order that MKL-DNN requires for the filter in a grouped
111 // convolution (2D only).
112 typedef enum {
113   MKL_GROUP_FILTER_DIM_G = 0,
114   MKL_GROUP_FILTER_DIM_O = 1,
115   MKL_GROUP_FILTER_DIM_I = 2,
116   MKL_GROUP_FILTER_DIM_H = 3,
117   MKL_GROUP_FILTER_DIM_W = 4
118 } MklDnnFilterGroupDims;
119 
120 // Enum used to templatize MklOp kernel implementations
121 // that support both fp32 and int8 versions.
122 enum class MklQuantization {
123   QUANTIZED_VERSION,
124   FP_VERSION,
125 };
126 
127 static const int kSmallBatchSize = 32;
128 
129 inline void execute_primitives(
130     std::vector<mkldnn::primitive>& primitives, std::shared_ptr<stream> stream,
131     std::vector<std::unordered_map<int, memory>>& net_args) {
132   DCHECK_EQ(primitives.size(), net_args.size());
133   for (size_t i = 0; i < primitives.size(); ++i) {
134     primitives.at(i).execute(*stream, net_args.at(i));
135   }
136 }
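// Usage sketch for execute_primitives(): each primitive is executed with the
// argument map at the same index, so the two vectors must be built in
// lockstep. A minimal sketch, assuming a reorder primitive_desc `reorder_pd`,
// memories `src_mem`/`dst_mem`, and a std::shared_ptr<stream> `cpu_stream`
// already exist:
//
//   std::vector<mkldnn::primitive> net;
//   std::vector<std::unordered_map<int, memory>> net_args;
//   net.push_back(mkldnn::reorder(reorder_pd));
//   net_args.push_back({{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
//   execute_primitives(net, cpu_stream, net_args);
//   cpu_stream->wait();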
137 
138 // In MKL-DNN v1.x, the format (ex. NCHW) used to initialize a memory descriptor
139 // (md) structure will no longer be recorded in its `format` field. Instead, it
140 // will be set to a canonical `blocked` format for every fully described md.
141 //
142 // Currently, we query this `format` field while mapping MKL-DNN's data format
143 // to TF's data format. Due to the above restriction, we will now get this data
144 // format information from TF's `data_format` attribute (i.e. via
145 // `TensorFormat`) for MKL-DNN v1.x.
146 //
147 // Some MKL-DNN operators such as ReLU do not have a `data_format` attribute
148 // since they are usually in `blocked` format. Therefore, in order to
149 // distinguish between blocked and non-blocked formats, we have defined a new
150 // enum called `MklTensorFormat` that is semantically similar to `TensorFormat`
151 // but with the following additional values:
152 //  1) FORMAT_BLOCKED: as described above, this is needed for element-wise
153 //     operators such as ReLU.
154 //  2) FORMAT_INVALID: for error-checking (ex. unsupported format)
155 //  3) FORMAT_X, FORMAT_NC, FORMAT_TNC: to distinguish between MKL tensors based
156 //     on their dimensions in operators such as Softmax, i.e.:
157 //        FORMAT_X   - 1D tensor
158 //        FORMAT_NC  - 2D tensor
159 //        FORMAT_TNC - 3D tensor
160 enum class MklTensorFormat {
161   FORMAT_NHWC = 0,
162   FORMAT_NCHW = 1,
163   FORMAT_NDHWC = 2,
164   FORMAT_NCDHW = 3,
165   FORMAT_X = 4,
166   FORMAT_NC = 5,
167   FORMAT_TNC = 6,
168   FORMAT_BLOCKED = 7,
169   FORMAT_INVALID = 8,
170 };
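// A minimal sketch of how the formats above round-trip through the helper
// functions declared below, e.g. for NHWC:
//
//   MklTensorFormat mkl_fmt = TFDataFormatToMklDnnDataFormat(FORMAT_NHWC);
//   // mkl_fmt == MklTensorFormat::FORMAT_NHWC
//   memory::format_tag tag = MklTensorFormatToMklDnnDataFormat(mkl_fmt);
//   // tag == memory::format_tag::nhwc
//   TensorFormat tf_fmt = MklDnnDataFormatToTFDataFormat(mkl_fmt);
//   // tf_fmt == FORMAT_NHWC
//
// FORMAT_BLOCKED has no format_tag equivalent;
// MklTensorFormatToMklDnnDataFormat() returns memory::format_tag::undef for it.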
171 
172 // Forward declarations
173 memory::format_tag MklTensorFormatToMklDnnDataFormat(MklTensorFormat format);
174 
175 TensorFormat MklDnn3DDataFormatToTFDataFormat(MklTensorFormat format);
176 TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format);
177 
178 memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
179 Status CreateBlockedMemDescHelper(const memory::dims& dim,
180                                   const memory::dims& strides,
181                                   memory::data_type dtype,
182                                   mkldnn_memory_desc_t* blocked_md);
183 
184 inline std::ostream& operator<<(std::ostream& os,
185                                 const memory::format_tag& tag) {
186   if (tag == memory::format_tag::undef) {
187     os << "undef";
188   } else if (tag == memory::format_tag::any) {
189     os << "any";
190   } else {
191     os << "invalid";
192   }
193   return os;
194 }
195 
196 inline void operator<<(std::ostream& os, const MklTensorFormat& format) {
197   if (format == MklTensorFormat::FORMAT_NHWC) {
198     os << "FORMAT_NHWC";
199   } else if (format == MklTensorFormat::FORMAT_NCHW) {
200     os << "FORMAT_NCHW";
201   } else if (format == MklTensorFormat::FORMAT_NDHWC) {
202     os << "FORMAT_NDHWC";
203   } else if (format == MklTensorFormat::FORMAT_NCDHW) {
204     os << "FORMAT_NCDHW";
205   } else if (format == MklTensorFormat::FORMAT_X) {
206     os << "FORMAT_X";
207   } else if (format == MklTensorFormat::FORMAT_NC) {
208     os << "FORMAT_NC";
209   } else if (format == MklTensorFormat::FORMAT_TNC) {
210     os << "FORMAT_TNC";
211   } else if (format == MklTensorFormat::FORMAT_BLOCKED) {
212     os << "FORMAT_BLOCKED";
213   } else {
214     os << "INVALID FORMAT";
215   }
216 }
217 
218 template <typename T>
219 inline bool array_cmp(const T* a1, const T* a2, size_t size) {
220   for (size_t i = 0; i < size; ++i)
221     if (a1[i] != a2[i]) return false;
222   return true;
223 }
224 
225 inline mkldnn::stream* CreateStream(MklDnnThreadPool* eigen_tp,
226                                     const engine& engine) {
227 #ifndef ENABLE_ONEDNN_OPENMP
228   if (eigen_tp != nullptr) {
229     stream* tp_stream =
230         new stream(dnnl::threadpool_interop::make_stream(engine, eigen_tp));
231     return tp_stream;
232   } else {
233     stream* tp_stream = new stream(engine);
234     return tp_stream;
235   }
236 #else
237   stream* tp_stream = new stream(engine);
238   return tp_stream;
239 #endif  // !ENABLE_ONEDNN_OPENMP
240 }
241 
242 class MklDnnShape {
243  private:
244   struct MklShapeData {
245     // Flag to indicate if the tensor is an MKL tensor or not
246     bool is_mkl_tensor_ = false;
247     // Number of dimensions in Tensorflow format
248     size_t dimension_ = 0;
249     mkldnn_dims_t sizes_;  // Required by MKL for conversions
250     MklTensorFormat tf_data_format_ = MklTensorFormat::FORMAT_BLOCKED;
251     memory::data_type T_ = memory::data_type::undef;
252     // MKL layout
253     mkldnn_memory_desc_t mkl_md_;
254     /// TF dimension corresponding to this MKL dimension
255     mkldnn_dims_t map_;
256   };
257   MklShapeData data_;
258 
259   typedef std::remove_extent<mkldnn_dims_t>::type mkldnn_dim_t;
260 
261 #define INVALID_DIM_SIZE -1
262 
263  public:
264   MklDnnShape() {
265     for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
266          ++i) {
267       data_.sizes_[i] = -1;
268     }
269     for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) {
270       data_.map_[i] = -1;
271     }
272   }
273 
274   ~MklDnnShape() {}
275   TF_DISALLOW_COPY_AND_ASSIGN(MklDnnShape);  // Cannot copy
276 
277   /// Equality function for MklDnnShape objects
278   /// @return true if both are equal; false otherwise.
279   inline bool operator==(const MklDnnShape& input_shape) const {
280     if (this->IsMklTensor() != input_shape.IsMklTensor()) {
281       return false;
282     }
283 
284     // If input tensors are in MKL layout, then we check for dimensions and
285     // sizes.
286     if (this->IsMklTensor()) {
287       const mkldnn_memory_desc_t& cur_md = (this->GetMklLayout()).data;
288       const mkldnn_memory_desc_t& input_shape_md =
289           input_shape.GetMklLayout().data;
290       return this->GetTfShape() == input_shape.GetTfShape() &&
291              mkldnn_memory_desc_equal(&cur_md, &input_shape_md);
292     }
293 
294     // Both inputs are not MKL tensors.
295     return true;
296   }
297 
298   /// Equality operator for MklDnnShape and TFShape.
299   /// Returns: true if TF shapes for both are the same, false otherwise
300   inline bool operator==(const TensorShape& input_shape) const {
301     if (!this->IsMklTensor()) {
302       return false;
303     }
304 
305     return this->GetTfShape() == input_shape;
306   }
307 
308   inline const bool IsMklTensor() const { return data_.is_mkl_tensor_; }
309   inline void SetMklTensor(bool is_mkl_tensor) {
310     data_.is_mkl_tensor_ = is_mkl_tensor;
311   }
312 
313   inline void SetDimensions(const size_t dimension) {
314     data_.dimension_ = dimension;
315   }
316   inline size_t GetDimension(char dimension) const {
317     int index = GetMklDnnTensorDimIndex(dimension);
318     CHECK(index >= 0 && index < this->GetDimension())
319         << "Invalid index from the dimension: " << index << ", " << dimension;
320     return this->DimSize(index);
321   }
322 
323   inline size_t GetDimension3D(char dimension) const {
324     int index = GetMklDnnTensor3DDimIndex(dimension);
325     CHECK(index >= 0 && index < this->GetDimension())
326         << "Invalid index from the dimension: " << index << ", " << dimension;
327     return this->DimSize(index);
328   }
329 
330   inline int32 GetMklDnnTensorDimIndex(char dimension) const {
331     switch (dimension) {
332       case 'N':
333         return MklDnnDims::Dim_N;
334       case 'C':
335         return MklDnnDims::Dim_C;
336       case 'H':
337         return MklDnnDims::Dim_H;
338       case 'W':
339         return MklDnnDims::Dim_W;
340       default:
341         LOG(FATAL) << "Invalid dimension: " << dimension;
342         return -1;  // Avoid compiler warning about missing return value
343     }
344   }
345 
346   inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
347     switch (dimension) {
348       case 'N':
349         return MklDnnDims3D::Dim3d_N;
350       case 'C':
351         return MklDnnDims3D::Dim3d_C;
352       case 'D':
353         return MklDnnDims3D::Dim3d_D;
354       case 'H':
355         return MklDnnDims3D::Dim3d_H;
356       case 'W':
357         return MklDnnDims3D::Dim3d_W;
358       default:
359         LOG(FATAL) << "Invalid dimension: " << dimension;
360         return -1;  // Avoid compiler warning about missing return value
361     }
362   }
363 
364   inline size_t GetDimension() const { return data_.dimension_; }
365   inline const int* GetSizes() const {
366     return reinterpret_cast<const int*>(&data_.sizes_[0]);
367   }
368 
369   // Returns an mkldnn::memory::dims object that contains the sizes of this
370   // MklDnnShape object.
371   inline memory::dims GetSizesAsMklDnnDims() const {
372     memory::dims retVal;
373     if (data_.is_mkl_tensor_) {
374       size_t dimensions = sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
375       for (size_t i = 0; i < dimensions; i++) {
376         if (data_.sizes_[i] != INVALID_DIM_SIZE)
377           retVal.push_back(data_.sizes_[i]);
378       }
379     } else {
380       CHECK_EQ(data_.is_mkl_tensor_, true);
381     }
382     return retVal;
383   }
384 
385   inline int64 DimSize(int index) const {
386     CHECK_LT(index, sizeof(data_.sizes_) / sizeof(data_.sizes_[0]));
387     return data_.sizes_[index];
388   }
389 
390   /// Return TensorShape that describes the Tensorflow shape of the tensor
391   /// represented by this MklShape.
392   inline TensorShape GetTfShape() const {
393     CHECK_EQ(data_.is_mkl_tensor_, true);
394 
395     std::vector<int32> shape(data_.dimension_, -1);
396     // As mentioned in the comment above, we now rely on TF's `data_format`
397     // attribute to determine if TF shape is in blocked format or not.
398     if (data_.tf_data_format_ != MklTensorFormat::FORMAT_BLOCKED) {
399       for (size_t idx = 0; idx < data_.dimension_; ++idx) {
400         shape[idx] = data_.sizes_[TfDimIdx(idx)];
401       }
402     } else {
403       // If Tensorflow shape is in Blocked format, then we don't have dimension
404       // map for it. So we just create Tensorflow shape from sizes in the
405       // specified order.
406       for (size_t idx = 0; idx < data_.dimension_; ++idx) {
407         shape[idx] = data_.sizes_[idx];
408       }
409     }
410 
411     TensorShape ts;
412     bool ret = TensorShapeUtils::MakeShape(shape, &ts).ok();
413     CHECK_EQ(ret, true);
414     return ts;
415   }
416 
417   inline void SetElemType(memory::data_type dt) { data_.T_ = dt; }
418   inline const memory::data_type GetElemType() { return data_.T_; }
419 
420   inline void SetMklLayout(memory::desc* md) {
421     CHECK_NOTNULL(md);
422     data_.mkl_md_ = md->data;
423   }
424 
425   inline const memory::desc GetMklLayout() const {
426     return memory::desc(data_.mkl_md_);
427   }
428 
429   inline MklTensorFormat GetTfDataFormat() const {
430     return data_.tf_data_format_;
431   }
432 
433   /// We don't create primitive_descriptor for TensorFlow layout now.
434   /// We use lazy evaluation and create it only when needed. Input format can
435   /// also be Blocked format.
436   inline void SetTfLayout(size_t dims, const memory::dims& sizes,
437                           MklTensorFormat format) {
438     DCHECK_EQ(dims, sizes.size())
439         << "SetTfLayout: Number of dimensions does not "
440            "match with dimension array";
441     data_.dimension_ = dims;
442     for (size_t ii = 0; ii < dims; ++ii) {
443       data_.sizes_[ii] = sizes[ii];
444     }
445     data_.tf_data_format_ = format;
446     if (format != MklTensorFormat::FORMAT_BLOCKED) {
447       if (dims == 2) {
448         data_.map_[0] = MklDnnDims::Dim_N;
449         data_.map_[1] = MklDnnDims::Dim_C;
450       } else {
451         SetTfDimOrder(dims, format);
452       }
453     }
454   }
455 
456   inline const memory::desc GetTfLayout() const {
457     memory::dims dims;
458     for (size_t ii = 0; ii < data_.dimension_; ++ii) {
459       dims.push_back(data_.sizes_[ii]);
460     }
461 
462     // Create Blocked memory desc if input TF format was set like that.
463     if (data_.tf_data_format_ == MklTensorFormat::FORMAT_BLOCKED) {
464       auto strides = CalculateTFStrides(dims);
465       mkldnn_memory_desc_t blocked_md;
466       TF_CHECK_OK(
467           CreateBlockedMemDescHelper(dims, strides, data_.T_, &blocked_md));
468       return memory::desc(blocked_md);
469     } else {
470       auto format_tag =
471           MklTensorFormatToMklDnnDataFormat(data_.tf_data_format_);
472       return memory::desc(dims, data_.T_, format_tag);
473     }
474   }
475 
476   inline const memory::desc GetCurLayout() const {
477     return IsMklTensor() ? GetMklLayout() : GetTfLayout();
478   }
479 
480   // We don't need a default dimension-order case: when an operator that has
481   // no data_format attribute receives all of its inputs in TensorFlow format,
482   // it will produce its output in TensorFlow format as well.
483   inline void SetTfDimOrder(const size_t dimension, const mkldnn_dims_t map) {
484     CHECK(dimension == data_.dimension_);
485     for (size_t ii = 0; ii < dimension; ii++) {
486       data_.map_[ii] = map[ii];
487     }
488   }
489 
490   inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
491     if (dimension == 5) {
492       CHECK(dimension == data_.dimension_);
493       data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
494           MklDnnDims3D::Dim3d_D;
495       data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
496           MklDnnDims3D::Dim3d_H;
497       data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
498           MklDnnDims3D::Dim3d_W;
499       data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
500           MklDnnDims3D::Dim3d_C;
501       data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
502           MklDnnDims3D::Dim3d_N;
503     } else {
504       CHECK_EQ(dimension, 4);
505       CHECK(dimension == data_.dimension_);
506       data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
507       data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
508       data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
509       data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
510     }
511   }
512 
513   inline void SetTfDimOrder(const size_t dimension, MklTensorFormat format) {
514     TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
515     SetTfDimOrder(dimension, data_format);
516   }
517 
518   inline const mkldnn_dim_t* GetTfToMklDimMap() const { return &data_.map_[0]; }
519   inline size_t TfDimIdx(int index) const { return data_.map_[index]; }
520   inline int64 TfDimSize(int index) const {
521     return data_.sizes_[TfDimIdx(index)];
522   }
523 
524   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
525   /// corresponds to MKL's Channel dimension.
526   inline bool IsMklChannelDim(int d) const {
527     return TfDimIdx(d) == MklDnnDims::Dim_C;
528   }
529 
530   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
531   /// corresponds to MKL's Batch dimension.
532   inline bool IsMklBatchDim(int d) const {
533     return TfDimIdx(d) == MklDnnDims::Dim_N;
534   }
535 
536   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
537   /// corresponds to MKL's Width dimension.
538   inline bool IsMklWidthDim(int d) const {
539     return TfDimIdx(d) == MklDnnDims::Dim_W;
540   }
541   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
542   /// corresponds to MKL's Height dimension.
543   inline bool IsMklHeightDim(int d) const {
544     return TfDimIdx(d) == MklDnnDims::Dim_H;
545   }
546 
547   /// Check whether the TF-MKL dimension ordering map specifies that the input
548   /// tensor is in NCHW format.
549   inline bool IsTensorInNCHWFormat() const {
550     TensorFormat data_format = FORMAT_NCHW;
551     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
552             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
553             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
554             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
555   }
556 
557   /// Check whether the TF-MKL dimension ordering map specifies that the input
558   /// tensor is in NHWC format.
559   inline bool IsTensorInNHWCFormat() const {
560     TensorFormat data_format = FORMAT_NHWC;
561     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
562             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
563             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
564             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
565   }
566 
567   /// The following methods are used for serializing and de-serializing the
568   /// contents of the MklDnnShape object.
569   /// The data is serialized in this order:
570   /// is_mkl_tensor_ : dimension_ : sizes_ : map_ : format_ : T_ : mkl_md_;
571 
572   /// Size of the buffer needed to hold the serialized object; the size is
573   /// computed by following the above-mentioned order.
574   inline size_t GetSerializeBufferSize() const { return sizeof(MklShapeData); }
575 
576   void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const {
577     CHECK(buf_size >= GetSerializeBufferSize())
578         << "Buffer size is too small to SerializeMklDnnShape";
579     *reinterpret_cast<MklShapeData*>(buf) = data_;
580   }
581 
582   void DeSerializeMklDnnShape(const unsigned char* buf, size_t buf_size) {
583     // Make sure buffer holds at least is_mkl_tensor_.
584     CHECK(buf_size >= sizeof(data_.is_mkl_tensor_))
585         << "Buffer size is too small in DeSerializeMklDnnShape";
586 
587     const bool is_mkl_tensor = *reinterpret_cast<const bool*>(buf);
588     if (is_mkl_tensor) {  // If it is an MKL Tensor then read the rest
589       CHECK(buf_size >= GetSerializeBufferSize())
590           << "Buffer size is too small in DeSerializeMklDnnShape";
591       data_ = *reinterpret_cast<const MklShapeData*>(buf);
592     }
593   }
594 };
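// A minimal usage sketch for MklDnnShape: a typical MKL kernel describes its
// output with an MklDnnShape and then serializes it into the metadata output
// tensor (see AllocateOutputSetMklShape below). Here `output_md` (the op
// output's memory::desc) and `dims_nchw` (sizes in MKL-DNN NCHW order) are
// placeholders:
//
//   MklDnnShape output_mkl_shape;
//   output_mkl_shape.SetMklTensor(true);
//   output_mkl_shape.SetElemType(MklDnnType<float>());
//   output_mkl_shape.SetMklLayout(&output_md);
//   output_mkl_shape.SetTfLayout(4, dims_nchw, MklTensorFormat::FORMAT_NHWC);
//   TensorShape tf_shape = output_mkl_shape.GetTfShape();  // NHWC order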
595 
596 // List of MklDnnShape objects. Used in Concat/Split layers.
597 typedef std::vector<MklDnnShape> MklDnnShapeList;
598 
599 template <typename T>
600 class MklDnnData;
601 
602 // TODO: Merge this with execute_primitives() above.
603 inline void ExecutePrimitive(const std::vector<primitive>& net,
604                              const std::vector<MemoryArgsMap>* net_args,
605                              const engine& cpu_engine,
606                              OpKernelContext* context = nullptr) {
607   DCHECK(net_args);
608   DCHECK_EQ(net.size(), net_args->size());
609   std::unique_ptr<stream> cpu_stream;
610   MklDnnThreadPool eigen_tp;
611   if (context != nullptr) {
612     eigen_tp = MklDnnThreadPool(context);
613     cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));
614   } else {
615     cpu_stream.reset(CreateStream(nullptr, cpu_engine));
616   }
617   for (size_t i = 0; i < net.size(); ++i) {
618     net.at(i).execute(*cpu_stream, net_args->at(i));
619   }
620   cpu_stream->wait();
621 }
622 template <typename T>
623 inline Status ConvertMklToTF(OpKernelContext* context,
624                              const Tensor& input_mkl_tensor,
625                              const MklDnnShape& input_mkl_shape,
626                              Tensor* output_tf_tensor) {
627   try {
628     if (!input_mkl_shape.IsMklTensor()) {
629       // Return input as is since it is already a TF tensor
630       *output_tf_tensor = input_mkl_tensor;
631       return Status::OK();
632     }
633 
634     // Allocate output tensor.
635     TensorShape output_tf_shape = input_mkl_shape.GetTfShape();
636     TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<T>::v(), output_tf_shape,
637                                        output_tf_tensor));
638 
639     engine cpu_engine(engine::kind::cpu, 0);
640     MklDnnData<T> input(&cpu_engine);
641 
642     // Get MKL layout of input tensor.
643     auto input_mkl_md = input_mkl_shape.GetMklLayout();
644     auto output_tf_md = input_mkl_shape.GetTfLayout();
645     input.SetUsrMem(input_mkl_md, &input_mkl_tensor);
646 
647     if (input.IsReorderNeeded(output_tf_md)) {
648       std::vector<primitive> net;
649       std::vector<MemoryArgsMap> net_args;
650       bool status = input.CheckReorderToOpMem(output_tf_md, output_tf_tensor,
651                                               net, net_args, cpu_engine);
652       if (!status) {
653         return Status(error::Code::INTERNAL,
654                       "ConvertMklToTF(): Failed to create reorder for input");
655       }
656       ExecutePrimitive(net, &net_args, cpu_engine, context);
657     } else {
658       // If not, just forward input tensor to output tensor.
659       bool status =
660           output_tf_tensor->CopyFrom(input_mkl_tensor, output_tf_shape);
661       if (!status) {
662         return Status(
663             error::Code::INTERNAL,
664             "ConvertMklToTF(): Failed to forward input tensor to output");
665       }
666     }
667     return Status::OK();
668   } catch (mkldnn::error& e) {
669     string error_msg = "Status: " + std::to_string(e.status) +
670                        ", message: " + string(e.message) + ", in file " +
671                        string(__FILE__) + ":" + std::to_string(__LINE__);
672     LOG(FATAL) << "Operation received an exception: " << error_msg;
673   }
674 }
675 
676 // Deserialize the MKL shape from the input's serialized metadata tensor.
677 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape,
678                         bool eager_mode) {
679   if (!eager_mode) {
680     mklshape->DeSerializeMklDnnShape(
681         ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
682             .flat<uint8>()
683             .data(),
684         ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
685                 .flat<uint8>()
686                 .size() *
687             sizeof(uint8));
688   } else {
689     mklshape->SetMklTensor(false);
690   }
691 }
692 
693 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
694   GetMklShape(ctext, n, mklshape, false);
695 }
696 
697 // Gets the actual input
698 inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
699   return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
700 }
701 
702 inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
703                             OpInputList* input_tensors) {
704   CHECK_NOTNULL(input_tensors);
705   TF_CHECK_OK(ctext->input_list(name, input_tensors));
706 }
707 
708 inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
709                             MklDnnShapeList* mkl_shapes,
710                             bool native_format = false) {
711   if (!native_format) {
712     OpInputList input_mkl_tensors;
713     GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
714 
715     for (int i = 0; i < input_mkl_tensors.size(); i++) {
716       (*mkl_shapes)[i].DeSerializeMklDnnShape(
717           input_mkl_tensors[i].flat<uint8>().data(),
718           input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
719     }
720   } else {
721     for (int i = 0; i < mkl_shapes->size(); ++i) {
722       (*mkl_shapes)[i].SetMklTensor(false);
723     }
724   }
725 }
726 
727 /// Get shape of input tensor pointed by 'input_idx' in TensorShape format.
728 /// If the input tensor is in MKL layout, then obtains TensorShape from
729 /// MklShape.
730 inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx,
731                               bool eager_mode = false) {
732   // Sanity check.
733   CHECK_NOTNULL(context);
734   CHECK_LT(input_idx, context->num_inputs());
735 
736   MklDnnShape input_mkl_shape;
737   GetMklShape(context, input_idx, &input_mkl_shape, eager_mode);
738   if (input_mkl_shape.IsMklTensor() && !eager_mode) {
739     return input_mkl_shape.GetTfShape();
740   } else {
741     const Tensor& t = MklGetInput(context, input_idx);
742     return t.shape();
743   }
744 }
745 
746 // Allocate the second output tensor that will contain
747 // the serialized MKL shape
748 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
749                                       const MklDnnShape& mkl_shape) {
750   Tensor* second_tensor = nullptr;
751   TensorShape second_shape;
752   second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
753   OP_REQUIRES_OK(ctext, ctext->allocate_output(
754                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
755                             second_shape, &second_tensor));
756   mkl_shape.SerializeMklDnnShape(
757       second_tensor->flat<uint8>().data(),
758       second_tensor->flat<uint8>().size() * sizeof(uint8));
759 }
760 
761 // Allocate the output tensor and create a second output tensor that will
762 // contain the serialized MKL shape
763 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
764                                       Tensor** output,
765                                       const TensorShape& tf_shape,
766                                       const MklDnnShape& mkl_shape,
767                                       bool eager_mode = false) {
768   OP_REQUIRES_OK(
769       ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
770                                     tf_shape, output));
771   if (!eager_mode) {
772     Tensor* second_tensor = nullptr;
773     TensorShape second_shape;
774     second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
775     OP_REQUIRES_OK(ctext, ctext->allocate_output(
776                               GetTensorMetaDataIndex(n, ctext->num_outputs()),
777                               second_shape, &second_tensor));
778     mkl_shape.SerializeMklDnnShape(
779         second_tensor->flat<uint8>().data(),
780         second_tensor->flat<uint8>().size() * sizeof(uint8));
781   }
782 }
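// A minimal usage sketch: allocating the data output and its serialized
// metadata in one call, for a kernel that produces a plain TF-format output
// at port 0 (`ctext` and `tf_shape` are placeholders for the kernel's context
// and output shape):
//
//   MklDnnShape output_mkl_shape;
//   output_mkl_shape.SetMklTensor(false);
//   Tensor* output_tensor = nullptr;
//   AllocateOutputSetMklShape(ctext, 0, &output_tensor, tf_shape,
//                             output_mkl_shape);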
783 
784 // Allocates a temp tensor and returns the data buffer for temporary storage.
785 template <typename T>
786 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
787                            const memory::desc& pd, void** buf_out) {
788   TensorShape tf_shape;
789 
790   tf_shape.AddDim(pd.get_size() / sizeof(T) + 1);
791   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
792                                                  tf_shape, tensor_out));
793   *buf_out = static_cast<void*>(tensor_out->flat<T>().data());
794 }
795 
796 template <typename T>
797 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
798                            TensorShape tf_shape) {
799   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
800                                                  tf_shape, tensor_out));
801 }
802 
803 inline void GetStridesFromSizes(MklTensorFormat data_format, size_t* strides,
804                                 const size_t* sizes) {
805   DCHECK_NE(data_format, MklTensorFormat::FORMAT_INVALID);
806   // MKL requires strides in NCHW
807   if (data_format == MklTensorFormat::FORMAT_NHWC) {
808     strides[0] = sizes[2];
809     strides[1] = sizes[0] * sizes[2];
810     strides[2] = 1;
811     strides[3] = sizes[0] * sizes[1] * sizes[2];
812   } else {
813     strides[0] = 1;
814     strides[1] = sizes[0];
815     strides[2] = sizes[0] * sizes[1];
816     strides[3] = sizes[0] * sizes[1] * sizes[2];
817   }
818 }
819 
820 inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
821                                  int idx_out) {
822   int num_inputs = context->num_inputs();
823   int num_outputs = context->num_outputs();
824   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
825   int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
826   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
827   int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
828 
829   const Tensor& data = context->input(idx_data_in);
830   const Tensor& meta = context->input(idx_meta_in);
831   Tensor output(data.dtype());
832   Tensor meta_output(meta.dtype());
833 
834   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
835   CHECK(output.CopyFrom(data, data.shape()));
836   CHECK(meta_output.CopyFrom(meta, meta.shape()));
837   context->set_output(idx_data_out, output);
838   context->set_output(idx_meta_out, meta_output);
839 }
840 
841 inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
842                                          int idx_out,
843                                          const TensorShape& shape) {
844   int num_inputs = context->num_inputs();
845   int num_outputs = context->num_outputs();
846   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
847   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
848 
849   const Tensor& data = context->input(idx_data_in);
850   MklDnnShape mkl_shape_output;
851   mkl_shape_output.SetMklTensor(false);
852   AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
853   Tensor output(data.dtype());
854   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
855   CHECK(output.CopyFrom(data, shape));
856   context->set_output(idx_data_out, output);
857 }
858 
859 inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
860                                    int idx_out) {
861   int num_inputs = context->num_inputs();
862   int num_outputs = context->num_outputs();
863   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
864   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
865 
866   MklDnnShape dnn_shape_output;
867   dnn_shape_output.SetMklTensor(false);
868   AllocateOutputSetMklShape(context, idx_out, dnn_shape_output);
869   if (IsRefType(context->input_dtype(idx_data_in))) {
870     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
871   } else {
872     context->set_output(idx_data_out, context->input(idx_data_in));
873   }
874 }
875 
876 inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in,
877                                     int idx_out) {
878   int num_inputs = context->num_inputs();
879   int num_outputs = context->num_outputs();
880   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
881   int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
882   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
883   int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
884 
885   if (IsRefType(context->input_dtype(idx_data_in))) {
886     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
887     context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
888   } else {
889     context->set_output(idx_data_out, context->input(idx_data_in));
890     context->set_output(idx_meta_out, context->input(idx_meta_in));
891   }
892 }
893 
894 // Set a dummy MKLDNN shape (called when the output is in TF format)
895 inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
896                                       uint32 idx_data_out) {
897   MklDnnShape mkl_shape_output;
898   mkl_shape_output.SetMklTensor(false);
899   AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
900 }
901 
902 // If the input tensor has a reference count of 1, it is forwarded to the
903 // desired output port and the function returns true. In that case, it also
904 // allocates the serialized MklDnnShape object. Otherwise, it returns false.
905 inline bool ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
906                                                 int idx_in, int idx_out,
907                                                 Tensor** output,
908                                                 const MklDnnShape& mkl_shape,
909                                                 bool always_forward = true) {
910   int num_inputs = context->num_inputs();
911   int num_outputs = context->num_outputs();
912   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
913   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
914   bool is_forwarded = false;
915   const Tensor& input_tensor = context->input(idx_data_in);
916   const auto output_shape = input_tensor.shape();
917   if (always_forward) {
918     if (IsRefType(context->input_dtype(idx_data_in))) {
919       context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
920     } else {
921       context->set_output(idx_data_out, input_tensor);
922     }
923   } else {
924     is_forwarded = context->forward_input_to_output_with_shape(
925         idx_data_in, idx_data_out, output_shape, output);
926   }
927   if (is_forwarded || always_forward) {
928     AllocateOutputSetMklShape(context, idx_out, mkl_shape);
929     return true;
930   }
931   return false;
932 }
933 
934 // Forward the MKL shape ONLY (used in elementwise and other ops where
935 // we call the Eigen implementation and the MKL shape is not used)
936 inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
937                                       uint32 idx_data_in,
938                                       uint32_t idx_data_out) {
939   uint32 idx_meta_in =
940       GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
941   uint32 idx_meta_out =
942       GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
943 
944   if (IsRefType(context->input_dtype(idx_data_in))) {
945     context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
946   } else {
947     context->set_output(idx_meta_out, context->input(idx_meta_in));
948   }
949 }
950 
951 // -------------------------------------------------------------------
952 //          Common utility functions used by MKL unit tests
953 
954 inline Tensor GetMklMetaTensor() {
955   MklDnnShape non_mkl_shape;
956   non_mkl_shape.SetMklTensor(false);
957 
958   auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize());
959   Tensor tensor(DT_UINT8, {size});
960 
961   non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
962                                      size * sizeof(uint8));
963   return tensor;
964 }
965 
966 // -------------------------------------------------------------------
967 
968 /// Return MKL-DNN data type (memory::data_type) for input type T
969 ///
970 /// @input None
971 /// @return memory::data_type corresponding to type T
972 template <typename T>
973 static memory::data_type MklDnnType();
974 
975 /// Instantiation for float type. Add similar instantiations for other
976 /// types if needed.
977 template <>
978 memory::data_type MklDnnType<float>() {
979   return memory::data_type::f32;
980 }
981 
982 template <>
983 memory::data_type MklDnnType<quint8>() {
984   return memory::data_type::u8;
985 }
986 
987 template <>
988 memory::data_type MklDnnType<uint8>() {
989   return memory::data_type::u8;
990 }
991 
992 template <>
993 memory::data_type MklDnnType<qint8>() {
994   return memory::data_type::s8;
995 }
996 
997 template <>
998 memory::data_type MklDnnType<qint32>() {
999   return memory::data_type::s32;
1000 }
1001 template <>
1002 memory::data_type MklDnnType<bfloat16>() {
1003   return memory::data_type::bf16;
1004 }
1005 
1006 // Map MklTensorFormat to MKL-DNN format tag
1007 //
1008 // @input: MklTensorFormat i.e. TensorFlow data format
1009 // @return: MKL-DNN's memory format tag corresponding to MklTensorFormat.
1010 //          Fails with an error if invalid data format.
1011 inline memory::format_tag MklTensorFormatToMklDnnDataFormat(
1012     MklTensorFormat format) {
1013   if (format == MklTensorFormat::FORMAT_NHWC) return memory::format_tag::nhwc;
1014   if (format == MklTensorFormat::FORMAT_NCHW) return memory::format_tag::nchw;
1015   if (format == MklTensorFormat::FORMAT_NDHWC) return memory::format_tag::ndhwc;
1016   if (format == MklTensorFormat::FORMAT_NCDHW) return memory::format_tag::ncdhw;
1017   if (format == MklTensorFormat::FORMAT_X) return memory::format_tag::x;
1018   if (format == MklTensorFormat::FORMAT_NC) return memory::format_tag::nc;
1019   if (format == MklTensorFormat::FORMAT_TNC) return memory::format_tag::tnc;
1020   return memory::format_tag::undef;
1021 }
1022 
1023 /// Map TensorFlow data format into MKL-DNN 3D data format
1024 /// @input: TensorFlow data format
1025 /// @return: MKL-DNN 3D data format corresponding to TensorFlow data format;
1026 ///          Fails with an error if invalid data format.
1027 inline MklTensorFormat TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
1028   if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NDHWC;
1029   if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCDHW;
1030   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1031   return MklTensorFormat::FORMAT_INVALID;
1032 }
1033 
1034 /// Map TensorFlow data format into MKL-DNN data format
1035 ///
1036 /// @input: TensorFlow data format
1037 /// @return: MKL-DNN data format corresponding to TensorFlow data format;
1038 ///          Fails with an error if invalid data format.
1039 inline MklTensorFormat TFDataFormatToMklDnnDataFormat(TensorFormat format) {
1040   if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NHWC;
1041   if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCHW;
1042   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1043   return MklTensorFormat::FORMAT_INVALID;
1044 }
1045 
1046 /// Map MKL-DNN data format into TensorFlow data format
1047 ///
1048 /// @input: MKL-DNN data format
1049 /// @return: Tensorflow data format corresponding to MKL-DNN data format;
1050 ///          Fails with an error if invalid data format.
1051 inline TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format) {
1052   if (format == MklTensorFormat::FORMAT_NHWC ||
1053       format == MklTensorFormat::FORMAT_NDHWC)
1054     return FORMAT_NHWC;
1055   if (format == MklTensorFormat::FORMAT_NCHW ||
1056       format == MklTensorFormat::FORMAT_NCDHW)
1057     return FORMAT_NCHW;
1058   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1059 
1060   // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure
1061   // that we don't come here.
1062   return FORMAT_NHWC;
1063 }
1064 
1065 /// Map TensorShape object into memory::dims required by MKL-DNN
1066 ///
1067 /// This function simply maps the input TensorShape into MKL-DNN dims,
1068 /// preserving the order of dimensions: e.g., if the input tensor is in NHWC
1069 /// format, then dims will also be in NHWC order.
1070 ///
1071 /// @input TensorShape object in shape
1072 /// @return memory::dims corresponding to TensorShape
1073 inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
1074   memory::dims dims(shape.dims());
1075   for (int d = 0; d < shape.dims(); ++d) {
1076     dims[d] = shape.dim_size(d);
1077   }
1078   return dims;
1079 }
1080 
1081 /// Map TensorShape object into memory::dims in NCHW format required by MKL-DNN
1082 ///
1083 /// This function is more specific than the one above: it maps the input
1084 /// TensorShape into MKL-DNN dims in NCHW order, so it may not preserve the
1085 /// order of dimensions. E.g., if the input tensor is in NHWC format, then
1086 /// dims will be in NCHW order, not in NHWC order.
1087 ///
1088 /// @input TensorShape object in shape
1089 /// @return memory::dims in MKL-DNN required NCHW format
1090 inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
1091                                               TensorFormat format) {
1092   // Check validity of format.
1093   DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1094             MklTensorFormat::FORMAT_INVALID);
1095 
1096   int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
1097   int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
1098   int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
1099   int w = shape.dim_size(GetTensorDimIndex(format, 'W'));
1100 
1101   // MKL-DNN requires dimensions in NCHW format.
1102   return memory::dims({n, c, h, w});
1103 }
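// Worked example: for an NHWC input of shape {8, 224, 224, 3},
//
//   TFShapeToMklDnnDimsInNCHW(TensorShape({8, 224, 224, 3}), FORMAT_NHWC)
//
// returns memory::dims{8, 3, 224, 224}, i.e. the same sizes reordered to NCHW.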
1104 
1105 inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
1106                                                TensorFormat format) {
1107   // Validate format.
1108   DCHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
1109             MklTensorFormat::FORMAT_INVALID);
1110 
1111   int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
1112   int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
1113   int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
1114   int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
1115   int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
1116 
1117   // MKL-DNN requires dimensions in NCDHW format.
1118   return memory::dims({n, c, d, h, w});
1119 }
1120 
1121 /// Variant of TFShapeToMklDnnDimsInNCHW above that takes memory::dims
1122 /// instead of a TensorShape. Input parameters are self-explanatory.
1123 inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
1124                                      TensorFormat format) {
1125   // Validate format.
1126   DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1127             MklTensorFormat::FORMAT_INVALID);
1128 
1129   int n = in_dims[GetTensorDimIndex(format, 'N')];
1130   int c = in_dims[GetTensorDimIndex(format, 'C')];
1131   int h = in_dims[GetTensorDimIndex(format, 'H')];
1132   int w = in_dims[GetTensorDimIndex(format, 'W')];
1133 
1134   // MKL-DNN requires dimensions in NCHW format.
1135   return memory::dims({n, c, h, w});
1136 }
1137 
1138 /// Variant of TFShapeToMklDnnDimsInNCDHW above that takes memory::dims
1139 /// instead of a TensorShape. Input parameters are self-explanatory.
1140 inline memory::dims MklDnnDimsInNCDHW(const memory::dims& in_dims,
1141                                       TensorFormat format) {
1142   // Validate format.
1143   DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1144             MklTensorFormat::FORMAT_INVALID);
1145 
1146   int n = in_dims[GetTensorDimIndex<3>(format, 'N')];
1147   int c = in_dims[GetTensorDimIndex<3>(format, 'C')];
1148   int d = in_dims[GetTensorDimIndex<3>(format, '0')];
1149   int h = in_dims[GetTensorDimIndex<3>(format, '1')];
1150   int w = in_dims[GetTensorDimIndex<3>(format, '2')];
1151 
1152   // MKL DNN requires dimensions in NCDHW format.
1153   return memory::dims({n, c, d, h, w});
1154 }
1155 
1156 /// Map MklDnn memory::dims object into TensorShape object.
1157 ///
1158 /// This function simply maps an input shape in MKL-DNN memory::dims format
1159 /// into TensorFlow's TensorShape object, preserving the dimension order.
1160 ///
1161 /// @input MKL-DNN memory::dims object
1162 /// @output TensorShape corresponding to memory::dims
1163 inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) {
1164   std::vector<int32> shape(dims.size(), -1);
1165   for (int d = 0; d < dims.size(); d++) {
1166     shape[d] = dims[d];
1167   }
1168 
1169   TensorShape ret;
1170   CHECK_EQ(TensorShapeUtils::MakeShape(shape, &ret).ok(), true);
1171   return ret;
1172 }
1173 
1174 /// Function to calculate strides given tensor shape in Tensorflow order
1175 /// E.g., if dims_tf_order is {1, 2, 3, 4}, then as per Tensorflow convention,
1176 /// dimension with size 1 is outermost dimension; while dimension with size 4 is
1177 /// innermost dimension. So strides for this tensor would be {4 * 3 * 2,
1178 /// 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}.
1179 ///
1180 /// @input Tensorflow shape in memory::dims type
1181 /// @return memory::dims containing strides for the tensor.
1182 inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) {
1183   CHECK_GT(dims_tf_order.size(), 0);
1184   memory::dims strides(dims_tf_order.size());
1185   int last_dim_idx = dims_tf_order.size() - 1;
1186   strides[last_dim_idx] = 1;
1187   for (int d = last_dim_idx - 1; d >= 0; d--) {
1188     strides[d] = strides[d + 1] * dims_tf_order[d + 1];
1189   }
1190   return strides;
1191 }
1192 
1193 /// Helper function to create memory descriptor in Blocked format
1194 ///
1195 /// @input: Tensor dimensions
1196 /// @input: strides corresponding to dimensions. One can use utility
1197 ///         function such as CalculateTFStrides to compute strides
1198 ///         for given dimensions.
1199 /// @output: mkldnn_memory_desc_t object corresponding to blocked memory
1200 ///          format for given dimensions and strides.
1201 /// @return: Status indicating whether the blocked memory descriptor
1202 ///          was successfully created.
1203 inline Status CreateBlockedMemDescHelper(const memory::dims& dim,
1204                                          const memory::dims& strides,
1205                                          memory::data_type dtype,
1206                                          mkldnn_memory_desc_t* blocked_md) {
1207   DCHECK_EQ(dim.size(), strides.size());
1208   const int kNumDims = dim.size();
1209   mkldnn_dim_t* input_dims = new mkldnn_dim_t[kNumDims];
1210   mkldnn_dim_t* input_strides = new mkldnn_dim_t[kNumDims];
1211   for (int i = 0; i < kNumDims; ++i) {
1212     input_dims[i] = dim[i];
1213     input_strides[i] = strides[i];
1214   }
1215   try {
1216     mkldnn_memory_desc_init_by_strides(blocked_md, kNumDims, input_dims,
1217                                        memory::convert_to_c(dtype),
1218                                        input_strides);
1219     delete[] input_dims;
1220     delete[] input_strides;
1221   } catch (mkldnn::error& e) {
1222     delete[] input_dims;
1223     delete[] input_strides;
1224     return Status(error::Code::INTERNAL,
1225                   tensorflow::strings::StrCat(
1226                       "Failed to create blocked memory descriptor.",
1227                       "Status: ", e.status, ", message: ", e.message));
1228   }
1229   return Status::OK();
1230 }
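// A minimal sketch combining CalculateTFStrides with
// CreateBlockedMemDescHelper to build a blocked memory descriptor for a
// row-major 2D float tensor:
//
//   memory::dims dims = {8, 16};                      // TF-order sizes
//   memory::dims strides = CalculateTFStrides(dims);  // {16, 1}
//   mkldnn_memory_desc_t blocked_md;
//   TF_CHECK_OK(CreateBlockedMemDescHelper(dims, strides,
//                                          memory::data_type::f32,
//                                          &blocked_md));
//   memory::desc md(blocked_md);  // wrap the C struct in the C++ type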
1231 
1232 inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
1233                                     const memory& src_mem,
1234                                     const memory& dst_mem, const engine& engine,
1235                                     OpKernelContext* ctx = nullptr) {
1236   std::vector<primitive> net;
1237   net.push_back(mkldnn::reorder(reorder_desc));
1238   std::vector<MemoryArgsMap> net_args;
1239   net_args.push_back({{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
1240   ExecutePrimitive(net, &net_args, engine, ctx);
1241 }
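// A minimal usage sketch for CreateAndExecuteReorder(), assuming memory
// descriptors `src_md`/`dst_md`, buffers `src_buf`/`dst_buf`, a CPU engine
// `cpu_engine`, and an OpKernelContext pointer `ctx` (may be nullptr):
//
//   memory src_mem(src_md, cpu_engine, src_buf);
//   memory dst_mem(dst_md, cpu_engine, dst_buf);
//   ReorderPd reorder_pd(cpu_engine, src_md, cpu_engine, dst_md);
//   CreateAndExecuteReorder(reorder_pd, src_mem, dst_mem, cpu_engine, ctx);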
1242 
1243 class MklReorderPrimitive;
1244 
1245 template <typename T>
1246 inline MklReorderPrimitive* FindOrCreateReorder(const memory* from,
1247                                                 const memory* to);
1248 
1249 // Class to represent all the resources corresponding to a tensor in TensorFlow
1250 // that are required to execute an operation (such as Convolution).
1251 template <typename T>
1252 class MklDnnData {
1253  private:
1254   /// MKL-DNN memory primitive for input user memory
1255   memory* user_memory_;
1256 
1257   /// MKL-DNN memory primitive in case input or output reorder is needed.
1258   memory* reorder_memory_;
1259 
1260   /// Operation's memory descriptor
1261   memory::desc* op_md_;
1262   /// Flag to indicate whether the data is 3D or not.
1263   bool bIs3D;
1264   /// Operation's temp buffer
1265   void* allocated_buffer_;
1266   /// CPU engine on which operation will be executed
1267   const engine* cpu_engine_;
1268 
1269  public:
1270   explicit MklDnnData(const engine* e)
1271       : user_memory_(nullptr),
1272         reorder_memory_(nullptr),
1273         op_md_(nullptr),
1274         bIs3D(false),
1275         allocated_buffer_(nullptr),
1276         cpu_engine_(e) {}
1277 
1278   ~MklDnnData() {
1279     if (allocated_buffer_ != nullptr) {
1280       cpu_allocator()->DeallocateRaw(allocated_buffer_);
1281     }
1282     cpu_engine_ = nullptr;  // We don't own this.
1283     delete (user_memory_);
1284     delete (reorder_memory_);
1285     delete (op_md_);
1286   }
1287 
1288   inline void* GetTensorBuffer(const Tensor* tensor) const {
1289     CHECK_NOTNULL(tensor);
1290     return const_cast<void*>(
1291         static_cast<const void*>(tensor->flat<T>().data()));
1292   }
1293 
1294   void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
1295   bool GetIs3D() { return bIs3D; }
1296 
  /// Set user memory primitive using the specified dimensions, memory format
  /// tag and data_buffer. The function automatically infers the element data
  /// type from the type T used to instantiate this object.
  ///
  /// In a nutshell, this function lets the user describe an input tensor to
  /// an operation. E.g., the filter of Conv2D has shape {1, 2, 3, 4} and
  /// memory format tag HWIO, and the buffer that contains the actual values
  /// is pointed to by data_buffer.
  inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm,
                        void* data_buffer = nullptr) {
    auto md = memory::desc(dim, MklDnnType<T>(), fm);
    SetUsrMem(md, data_buffer);
  }

  inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm,
                        const Tensor* tensor) {
    DCHECK(tensor);
    SetUsrMem(dim, fm, GetTensorBuffer(tensor));
  }

  /// Helper function to create memory descriptor in Blocked format
  ///
  /// @input: Tensor dimensions
  /// @input: strides corresponding to dimensions. One can use utility
  ///         function such as CalculateTFStrides to compute strides
  ///         for given dimensions.
  /// @return: memory::desc object corresponding to blocked memory format
  ///          for given dimensions and strides.
  static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim,
                                                  const memory::dims& strides) {
    mkldnn_memory_desc_t blocked_md;
    TF_CHECK_OK(
        CreateBlockedMemDescHelper(dim, strides, MklDnnType<T>(), &blocked_md));
    return memory::desc(blocked_md);
  }

  /// A version of SetUsrMem call that allows the user to create memory in
  /// blocked format. In addition to accepting dimensions, it also accepts
  /// strides. This allows the user to create memory for a tensor in a format
  /// that MKL-DNN does not support natively. E.g., MKL-DNN has no native
  /// format for 6-dimensional tensors, but by using the blocked format a user
  /// can still create memory for a 6D tensor.
  inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
                        void* data_buffer = nullptr) {
    CHECK_EQ(dim.size(), strides.size());
    auto blocked_md = MklDnnData<T>::CreateBlockedMemDesc(dim, strides);
    SetUsrMem(blocked_md, data_buffer);
  }

  inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
                        const Tensor* tensor) {
    CHECK_NOTNULL(tensor);
    SetUsrMem(dim, strides, GetTensorBuffer(tensor));
  }

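  // Illustrative sketch (not part of this header): describing a 6D tensor via
  // the blocked format. The strides are hand-computed row-major strides; a
  // helper such as CalculateTFStrides could compute them instead. `cpu_engine`
  // and `buffer_ptr` are assumed to exist in the caller.
  //
  //   MklDnnData<float> data(&cpu_engine);
  //   memory::dims dims = {2, 2, 2, 2, 2, 2};
  //   memory::dims strides = {32, 16, 8, 4, 2, 1};  // row-major
  //   data.SetUsrMem(dims, strides, buffer_ptr);
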
  /// A version of SetUsrMem with memory descriptor and tensor
  inline void SetUsrMem(const memory::desc& md, const Tensor* tensor) {
    CHECK_NOTNULL(tensor);
    SetUsrMem(md, GetTensorBuffer(tensor));
  }

  /// A version of function to set user memory type that accepts memory
  /// descriptor directly, instead of accepting dimensions and format. This
  /// function is more generic than the one above, but the function above is
  /// sufficient in most cases.
  inline void SetUsrMem(const memory::desc& pd, void* data_buffer = nullptr) {
    DCHECK(cpu_engine_);
    if (user_memory_) delete user_memory_;
    // TODO(nhasabni): can we remove dynamic memory allocation?
    if (data_buffer) {
      user_memory_ = new memory(pd, *cpu_engine_, data_buffer);
    } else {
      user_memory_ = new memory(pd, *cpu_engine_);
    }
  }

  /// Get function for user memory primitive.
  inline const memory* GetUsrMem() const { return user_memory_; }

  /// Get function for descriptor of user memory.
  inline memory::desc GetUsrMemDesc() const {
    DCHECK(user_memory_);
    return user_memory_->get_desc();
  }

  /// Get function for data buffer of user memory primitive.
  inline void* GetUsrMemDataHandle() const {
    CHECK_NOTNULL(user_memory_);
    return user_memory_->get_data_handle();
  }

  /// Set function for data buffer of user memory primitive.
  inline void SetUsrMemDataHandle(void* data_buffer,
                                  std::shared_ptr<stream> t_stream = nullptr) {
    CHECK_NOTNULL(user_memory_);
    CHECK_NOTNULL(data_buffer);
#ifndef ENABLE_ONEDNN_OPENMP
    user_memory_->set_data_handle(data_buffer, *t_stream);
#else
    user_memory_->set_data_handle(data_buffer);
#endif  // !ENABLE_ONEDNN_OPENMP
  }

  /// Set function for data buffer of user memory primitive.
  inline void SetUsrMemDataHandle(const Tensor* tensor,
                                  std::shared_ptr<stream> t_stream = nullptr) {
    SetUsrMemDataHandle(GetTensorBuffer(tensor), t_stream);
  }

  /// Allocate function for data buffer
  inline void AllocateBuffer(size_t size) {
    const int64 kMemoryAlignment = 64;  // For AVX512 memory alignment.
    allocated_buffer_ = cpu_allocator()->AllocateRaw(kMemoryAlignment, size);
  }

  inline void* GetAllocatedBuffer() { return allocated_buffer_; }

  /// Get the memory primitive for input and output of an op. If inputs
  /// to an op require reorders, then this function returns memory primitive
  /// for reorder. Otherwise, it will return memory primitive for user memory.
  ///
  /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
  /// execute Conv2D, we need memory primitive for I and F. But if reorder is
  /// required for I and F (say I_r is reorder primitive for I; F_r is reorder
  /// primitive for F), then we need I_r and F_r to perform Conv2D.
  inline const memory& GetOpMem() const {
    return reorder_memory_ ? *reorder_memory_ : *user_memory_;
  }

  /// Set memory descriptor of an operation in terms of dimensions and memory
  /// format. E.g., For Conv2D, the dimensions would be same as user dimensions
  /// but memory::format_tag would be mkldnn::any because we want MKL-DNN to
  /// choose the best layout/format for given input dimensions.
  inline void SetOpMemDesc(const memory::dims& dim, memory::format_tag fm) {
    // TODO(nhasabni): can we remove dynamic memory allocation?
    op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
  }

  /// Get function for memory descriptor for an operation
  inline const memory::desc& GetOpMemDesc() const { return *op_md_; }

  /// Predicate that checks if we need to reorder user's memory into memory
  /// pointed to by op_md.
  ///
  /// @input: op_md - memory descriptor of the given input of an operation.
  /// @return: true in case reorder of input is needed; false, otherwise.
  inline bool IsReorderNeeded(const memory::desc& op_pd) const {
    DCHECK(user_memory_);
    return op_pd != user_memory_->get_desc();
  }

  /// Function to create a reorder from memory pointed to by `from` to memory
  /// pointed to by `to`. Returns the created primitive.
  inline primitive CreateReorder(const memory* from, const memory* to) const {
    CHECK_NOTNULL(from);
    CHECK_NOTNULL(to);
    return reorder(*from, *to);
  }

  /// Function to handle input reordering
  ///
  /// Check if we need to reorder this input of an operation.
  /// Return true and allocate reorder memory primitive if reorder is needed.
  /// Otherwise, return false and do not allocate reorder memory primitive.
  ///
  /// To check if reorder is needed, this function compares memory primitive
  /// descriptor (memory descriptor for v1.x) of an operation (op_pd) for
  /// the given input with the user-specified memory descriptor.
  ///
  /// @input: op_pd - memory primitive descriptor of the given input of an
  ///                 operation
  /// @input: net - net to which to add reorder primitive in case it is needed.
  /// @input: net_args - net to which user and reorder memories are added if
  ///                    needed. Each entry is a key-value pair of the form
  ///                    <argument-type, mkldnn::memory>.
  /// @return: true in case reorder of input is needed; false, otherwise.
  inline bool CheckReorderToOpMem(const memory::desc& op_md,
                                  std::vector<primitive>& net,
                                  std::vector<MemoryArgsMap>& net_args,
                                  const engine& engine) {
    DCHECK(user_memory_);
    DCHECK_EQ(net.size(), net_args.size());
    if (IsReorderNeeded(op_md)) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      reorder_memory_ = new memory(op_md, engine);
      net.push_back(CreateReorder(user_memory_, reorder_memory_));
      net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *user_memory_},
                                       {MKLDNN_ARG_TO, *reorder_memory_}});
      return true;
    }
    return false;
  }

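  // Illustrative sketch (not part of this header): typical input handling in
  // an MKL kernel. `cpu_engine`, `input_tensor`, `src_dims` and `op_src_md`
  // (e.g. taken from a convolution primitive descriptor) are assumed.
  //
  //   MklDnnData<float> src(&cpu_engine);
  //   src.SetUsrMem(src_dims, memory::format_tag::nhwc, &input_tensor);
  //   std::vector<primitive> net;
  //   std::vector<MemoryArgsMap> net_args;
  //   src.CheckReorderToOpMem(op_src_md, net, net_args, cpu_engine);
  //   const memory& conv_src_mem = src.GetOpMem();  // reordered if needed
  //   // ... later, execute `net` together with the op's own primitives ...
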
  inline bool CheckReorderToOpMem(const memory::desc& op_md,
                                  const engine& engine,
                                  OpKernelContext* context = nullptr) {
    DCHECK(user_memory_);
    if (IsReorderNeeded(op_md)) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      // Primitive reuse does not allow two identical reorder primitives in
      // one stream, so submit it immediately.
      reorder_memory_ = new memory(op_md, engine);
      auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
      std::shared_ptr<stream> cpu_stream;
      MklDnnThreadPool eigen_tp;
      if (context != nullptr) {
        eigen_tp = MklDnnThreadPool(context);
        cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine()));
      } else {
        cpu_stream.reset(CreateStream(nullptr, prim->GetEngine()));
      }
      std::vector<primitive> net;
      net.push_back(*(prim->GetPrimitive()));
      std::vector<MemoryArgsMap> net_args;
      net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
                          {MKLDNN_ARG_TO, *reorder_memory_}});
      execute_primitives(net, cpu_stream, net_args);
      return true;
    }
    return false;
  }

  /// Overloaded version of above function that accepts a memory buffer
  /// where the output of the reorder needs to be stored.
  ///
  /// @input: op_pd - memory primitive descriptor (memory descriptor for v1.x)
  ///                 of the given input of an operation
  /// @reorder_data_handle - memory buffer where the output of the reorder
  ///                        needs to be stored. The primitive does not check
  ///                        whether the buffer is large enough.
  /// @input: net - net to which to add reorder primitive in case it is needed.
  /// @input: net_args - net to which user and reorder memories are added if
  ///                    needed. Each entry is a key-value pair of the form
  ///                    <argument-type, mkldnn::memory>.
  /// @input: engine - MKL-DNN's abstraction of a computational device
  /// @return: true in case reorder of input is needed; false, otherwise.
  inline bool CheckReorderToOpMem(const memory::desc& op_md,
                                  void* reorder_data_handle,
                                  std::vector<primitive>& net,
                                  std::vector<MemoryArgsMap>& net_args,
                                  const engine& engine) {
    DCHECK(reorder_data_handle);
    DCHECK(user_memory_);
    if (IsReorderNeeded(op_md)) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
      net.push_back(CreateReorder(user_memory_, reorder_memory_));
      net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *user_memory_},
                                       {MKLDNN_ARG_TO, *reorder_memory_}});
      return true;
    }
    return false;
  }

  /// This is a faster path with the reorder primitive cache compared with
  /// CheckReorderToOpMem(..., std::vector<primitive>* net).
  /// The slower path will be removed in the future.
  /// TODO(bhavanis): Need to use reorder cache here for better performance.
  inline bool CheckReorderToOpMem(const memory::desc& op_md,
                                  void* reorder_data_handle,
                                  const engine& engine,
                                  OpKernelContext* context = nullptr) {
    DCHECK(reorder_data_handle);
    DCHECK(user_memory_);
    if (IsReorderNeeded(op_md)) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      // Primitive reuse does not allow two identical reorder primitives in
      // one stream, so submit it immediately.
      reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
      auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
      std::shared_ptr<stream> cpu_stream;
      MklDnnThreadPool eigen_tp;
      if (context != nullptr) {
        eigen_tp = MklDnnThreadPool(context);
        cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine()));
      } else {
        cpu_stream.reset(CreateStream(nullptr, prim->GetEngine()));
      }
      std::vector<primitive> net;
      net.push_back(*(prim->GetPrimitive()));
      std::vector<MemoryArgsMap> net_args;
      net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
                          {MKLDNN_ARG_TO, *reorder_memory_}});
      execute_primitives(net, cpu_stream, net_args);
      return true;
    }
    return false;
  }

  /// Another overloaded version of CheckReorderToOpMem that accepts a Tensor
  /// where the output of the reorder needs to be stored.
  ///
  /// @input: op_md - memory primitive descriptor (memory descriptor for v1.x)
  ///                 of the given input of an operation
  /// @reorder_tensor - Tensor whose buffer is to be used to store the output
  ///                   of the reorder. The primitive does not check whether
  ///                   the buffer is large enough.
  /// @input: net - net to which to add reorder primitive in case it is needed.
  /// @input: net_args - net to which user and reorder memories are added if
  ///                    needed. Each entry is a key-value pair of the form
  ///                    <argument-type, mkldnn::memory>.
  /// @input: engine - MKL-DNN's abstraction of a computational device
  /// @return: true in case reorder of input is needed; false, otherwise.
  inline bool CheckReorderToOpMem(const memory::desc& op_md,
                                  Tensor* reorder_tensor,
                                  std::vector<primitive>& net,
                                  std::vector<MemoryArgsMap>& net_args,
                                  const engine& engine) {
    DCHECK(reorder_tensor);
    return CheckReorderToOpMem(op_md, GetTensorBuffer(reorder_tensor), net,
                               net_args, engine);
  }

  /// TODO: this is a faster path with the reorder primitive cache compared
  /// with CheckReorderToOpMem(op_md, reorder_tensor, net, net_args, engine);
  /// the slow path will be removed in the future.
  inline bool CheckReorderToOpMem(const memory::desc& op_pd,
                                  Tensor* reorder_tensor,
                                  OpKernelContext* ctx = nullptr) {
    DCHECK(reorder_tensor);
    return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor),
                               *cpu_engine_, ctx);
  }

  /// Function to handle output reorder
  ///
  /// This function performs very similar functionality to the input reordering
  /// function above. The only difference is that this function does not add
  /// the reorder primitive to the net. The reason is that the reorder
  /// primitive for the output needs to be added to the net only after the
  /// operation has executed. But we still need to prepare a temporary buffer
  /// in case an output reorder is needed; this temporary buffer will hold the
  /// output of the operation before it is fed to the reorder primitive.
  ///
  /// @input - memory primitive descriptor (memory descriptor for v1.x) for the
  ///          given output of an operation
  /// @return: true in case reorder of output is needed; false, otherwise.
  inline bool PrepareReorderToUserMemIfReq(const memory::desc& op_pd) {
    DCHECK(user_memory_);
    if (IsReorderNeeded(op_pd)) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      reorder_memory_ = new memory(op_pd, *cpu_engine_);
      return true;
    }
    return false;
  }

  /// Function to actually insert reorder primitive in the net
  ///
  /// This function completes remaining part of output reordering. It inserts
  /// a reordering primitive from the temporary buffer that holds the output
  /// to the user-specified output buffer.
  ///
  /// @input: net - net to which to add reorder primitive
  /// @input: net_args - net to which user and reorder memories are added if
  ///                    needed. Each entry is a key-value pair of the form
  ///                    <argument-type, mkldnn::memory>.
  inline void InsertReorderToUserMem(std::vector<primitive>& net,
                                     std::vector<MemoryArgsMap>& net_args) {
    DCHECK(user_memory_);
    DCHECK(reorder_memory_);
    net.push_back(CreateReorder(reorder_memory_, user_memory_));
    net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *reorder_memory_},
                                     {MKLDNN_ARG_TO, *user_memory_}});
  }

  /// TODO: this is a faster path with the reorder primitive cache compared
  ///       with InsertReorderToUserMem(net, net_args); the slow path will be
  ///       removed in the future.
  inline void InsertReorderToUserMem(OpKernelContext* ctx = nullptr) {
    DCHECK(user_memory_);
    DCHECK(reorder_memory_);
    DCHECK(cpu_engine_);
    // Primitive reuse does not allow two identical reorder primitives in
    // one stream, so submit it immediately.
    std::vector<primitive> net;
    auto* prim = FindOrCreateReorder<T>(reorder_memory_, user_memory_);
    net.push_back(*(prim->GetPrimitive()));
    std::vector<MemoryArgsMap> net_args;
    net_args.push_back(
        {{MKLDNN_ARG_FROM, *reorder_memory_}, {MKLDNN_ARG_TO, *user_memory_}});
    std::shared_ptr<stream> cpu_stream;
    MklDnnThreadPool eigen_tp;
    if (ctx != nullptr) {
      eigen_tp = MklDnnThreadPool(ctx);
      cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine()));
    } else {
      cpu_stream.reset(CreateStream(nullptr, prim->GetEngine()));
    }
    execute_primitives(net, cpu_stream, net_args);
  }
};

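// Illustrative sketch (not part of this header): typical output handling in
// an MKL kernel. `cpu_engine`, `ctx`, `output_tensor`, `dst_dims` and
// `op_dst_md` (e.g. taken from the op's primitive descriptor) are assumed.
//
//   MklDnnData<float> dst(&cpu_engine);
//   dst.SetUsrMem(dst_dims, memory::format_tag::nhwc, &output_tensor);
//   if (dst.PrepareReorderToUserMemIfReq(op_dst_md)) {
//     // ... execute the op writing into dst.GetOpMem() ...
//     dst.InsertReorderToUserMem(ctx);  // copy back into the user layout
//   } else {
//     // ... execute the op writing directly into the user buffer ...
//   }
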
/// Base class for operations with reuse of primitives
class MklPrimitive {
 public:
  virtual ~MklPrimitive() {}
  MklPrimitive() {}
  MklPrimitive(const engine& cpu_engine) { cpu_engine_ = cpu_engine; }
  // Dummy data which MKL DNN never operates on
  unsigned char* DummyData = nullptr;
  engine cpu_engine_ = engine(engine::kind::cpu, 0);
  const engine& GetEngine() { return cpu_engine_; }
};

const mkldnn::memory::dims NONE_DIMS = {};

//
// LRUCache is a class which implements LRU (Least Recently Used) cache.
// The implementation is similar to that of
//    tensorflow/core/platform/cloud/expiring_lru_cache.h
// without its thread-safe part, because the cache is supposed to be
// used as thread local (for instance, MklPrimitive caching).
//
// The LRU list maintains objects in order of most recent access, with the
// least recently accessed object at the tail of the LRU list and the most
// recently accessed object at the head.
//
// This class is used to maintain an upper bound on the total number of
// cached items. When the cache reaches its capacity, the LRU item is
// removed and replaced by the new one from the SetOp call.
//
template <typename T>
class LRUCache {
 public:
  explicit LRUCache(size_t capacity) {
    capacity_ = capacity;
    Clear();
  }

  T* GetOp(const string& key) {
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      return nullptr;
    }

    // Move to the front of LRU list as the most recently accessed.
    lru_list_.erase(it->second.lru_iterator);
    lru_list_.push_front(it->first);
    it->second.lru_iterator = lru_list_.begin();
    return it->second.op;
  }

  void SetOp(const string& key, T* op) {
    if (lru_list_.size() >= capacity_) {
      Delete();
    }

    // Insert an entry to the front of the LRU list
    lru_list_.push_front(key);
    Entry entry(op, lru_list_.begin());
    cache_.emplace(std::make_pair(key, std::move(entry)));
  }

  void Clear() {
    if (lru_list_.empty()) return;

    // Clean up the cache
    cache_.clear();
    lru_list_.clear();
  }

 private:
  struct Entry {
    // The entry's value.
    T* op;

    // A list iterator pointing to the entry's position in the LRU list.
    std::list<string>::iterator lru_iterator;

    // Constructor
    Entry(T* op, std::list<string>::iterator it) {
      this->op = op;
      this->lru_iterator = it;
    }

    // Move constructor
    Entry(Entry&& source) noexcept
        : lru_iterator(std::move(source.lru_iterator)) {
      op = std::move(source.op);
      source.op = std::forward<T*>(nullptr);
    }

    // Destructor
    ~Entry() {
      if (op != nullptr) delete op;
    }
  };

  // Remove the least recently accessed entry from LRU list, which
  // is the tail of lru_list_. Update cache_ correspondingly.
  bool Delete() {
    if (lru_list_.empty()) return false;
    string key = lru_list_.back();
    lru_list_.pop_back();
    cache_.erase(key);
    return true;
  }

  // Cache capacity
  size_t capacity_;

  // The cache, a map from string key to a LRU entry.
  std::unordered_map<string, Entry> cache_;

  // The LRU list of entries.
  // The front of the list contains the key of the most recently accessed
  // entry, while the back of the list is the least recently accessed entry.
  std::list<string> lru_list_;
};

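// Illustrative sketch (not part of this header): caching heap-allocated
// objects in an LRUCache. The cache takes ownership of the stored pointers
// (they are deleted on eviction). `MyPrimitive` is a hypothetical type.
//
//   LRUCache<MyPrimitive> cache(/*capacity=*/2);
//   cache.SetOp("a", new MyPrimitive());
//   cache.SetOp("b", new MyPrimitive());
//   MyPrimitive* a = cache.GetOp("a");    // "a" becomes most recently used
//   cache.SetOp("c", new MyPrimitive());  // evicts "b", the LRU entry
//   DCHECK(cache.GetOp("b") == nullptr);
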
template <typename T>
class MklPrimitiveFactory {
 public:
  MklPrimitiveFactory() {}

  ~MklPrimitiveFactory() {}

  MklPrimitive* GetOp(const string& key) {
    auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
    return lru_cache.GetOp(key);
  }

  void SetOp(const string& key, MklPrimitive* op) {
    auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
    lru_cache.SetOp(key, op);
  }

  /// Function to decide whether the HW has AVX512 or AVX2.
  /// For legacy devices (without AVX512 and AVX2),
  /// MKL-DNN GEMM will be used.
  static inline bool IsLegacyPlatform() {
    static const bool is_legacy_platform =
        (!port::TestCPUFeature(port::CPUFeature::AVX512F) &&
         !port::TestCPUFeature(port::CPUFeature::AVX2));
    return is_legacy_platform;
  }

  /// Function to check whether primitive memory optimization is enabled
  static inline bool IsPrimitiveMemOptEnabled() {
    static const bool is_primitive_mem_opt_enabled = [] {
      bool value = true;
      TF_CHECK_OK(
          ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true, &value));
      return value;
    }();
    return is_primitive_mem_opt_enabled;
  }

 private:
  static inline LRUCache<MklPrimitive>& GetLRUCache() {
    static const int kCapacity = 1024;  // cache capacity
    static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity);
    return lru_cache_;
  }
};

// Utility class for creating keys for the MKL primitive pool.
class FactoryKeyCreator {
 public:
  FactoryKeyCreator() { key_.reserve(kMaxKeyLength); }

  ~FactoryKeyCreator() {}

  void AddAsKey(const string& str) { Append(str); }

  void AddAsKey(const mkldnn::memory::dims& dims) {
    for (unsigned int i = 0; i < dims.size(); i++) {
      AddAsKey<int>(dims[i]);
    }
  }

  template <typename T>
  void AddAsKey(const T data) {
    auto buffer = reinterpret_cast<const char*>(&data);
    Append(StringPiece(buffer, sizeof(T)));
  }

  // Generalization to handle pointers.
  void AddAsKey(const void* data) {
    auto buffer = reinterpret_cast<const char*>(&data);
    Append(StringPiece(buffer, sizeof(data)));
  }

  string GetKey() { return key_; }

 private:
  string key_;
  const char delimiter = 'x';
  const int kMaxKeyLength = 256;
  void Append(StringPiece s) {
    key_.append(string(s));
    key_.append(1, delimiter);
  }
};

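// Illustrative sketch (not part of this header): building a cache key from a
// prefix, a dims vector, and a scalar. The resulting key is a binary string
// of the appended fields separated by the 'x' delimiter. The field values
// below are made up for illustration.
//
//   FactoryKeyCreator key_creator;
//   key_creator.AddAsKey(string("conv2d_fwd"));
//   key_creator.AddAsKey(memory::dims({1, 3, 224, 224}));
//   key_creator.AddAsKey<float>(0.5f);  // e.g. a scale attribute
//   string key = key_creator.GetKey();
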
class MklReorderPrimitive : public MklPrimitive {
 public:
  explicit MklReorderPrimitive(const memory* from, const memory* to)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    Setup(from, to);
  }
  ~MklReorderPrimitive() {}

  std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }

  void SetMemory(const memory* from, const memory* to) {
    context_.src_mem->set_data_handle(from->get_data_handle());
    context_.dst_mem->set_data_handle(to->get_data_handle());
  }

  std::shared_ptr<mkldnn::stream> GetStream() { return stream_; }

 private:
  struct ReorderContext {
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;
    std::shared_ptr<primitive> reorder_prim;
    ReorderContext()
        : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
  } context_;

  std::shared_ptr<mkldnn::stream> stream_;

  void Setup(const memory* from, const memory* to) {
    context_.src_mem.reset(
        new memory(from->get_desc(), cpu_engine_, DummyData));
    context_.dst_mem.reset(new memory(to->get_desc(), cpu_engine_, DummyData));
    context_.reorder_prim = std::make_shared<mkldnn::reorder>(
        reorder(*context_.src_mem, *context_.dst_mem));
    stream_.reset(new stream(cpu_engine_));
  }
};

template <typename T>
class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklReorderPrimitive* Get(const memory* from, const memory* to) {
    auto reorderPrim = static_cast<MklReorderPrimitive*>(
        MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
    if (reorderPrim == nullptr) {
      reorderPrim = new MklReorderPrimitive(from, to);
      MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to,
                                                              reorderPrim);
    }
    reorderPrim->SetMemory(from, to);
    return reorderPrim;
  }

  static MklReorderPrimitiveFactory& GetInstance() {
    static MklReorderPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const memory* from, const memory* to) {
    string prefix = "reorder";
    FactoryKeyCreator key_creator;
    auto const& from_desc = from->get_desc().data;
    auto const& to_desc = to->get_desc().data;
    memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
    memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
    auto from_strides = from_desc.format_desc.blocking.strides;

    // The DNNL memory descriptor stores these fields as C-style arrays and
    // only initializes the used part, so only the valid portion is used in
    // the key.
    auto from_inner_nblks = from_desc.format_desc.blocking.inner_nblks;
    auto from_inner_blks = from_desc.format_desc.blocking.inner_blks;
    auto from_inner_idxs = from_desc.format_desc.blocking.inner_idxs;
    memory::dims from_inner_blks_1(from_inner_blks,
                                   &from_inner_blks[from_inner_nblks]);
    memory::dims from_inner_idxs_1(from_inner_idxs,
                                   &from_inner_idxs[from_inner_nblks]);
    auto to_inner_nblks = to_desc.format_desc.blocking.inner_nblks;
    auto to_inner_blks = to_desc.format_desc.blocking.inner_blks;
    auto to_inner_idxs = to_desc.format_desc.blocking.inner_idxs;
    memory::dims to_inner_blks_1(to_inner_blks, &to_inner_blks[to_inner_nblks]);
    memory::dims to_inner_idxs_1(to_inner_idxs, &to_inner_idxs[to_inner_nblks]);

    auto to_strides = to_desc.format_desc.blocking.strides;
    memory::dims from_strides_outer_blocks(from_strides,
                                           &from_strides[from_desc.ndims]);
    memory::dims to_strides_outer_blocks(to_strides,
                                         &to_strides[to_desc.ndims]);

    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(static_cast<int>(from_desc.extra.flags));
    key_creator.AddAsKey(static_cast<int>(from_inner_nblks));
    key_creator.AddAsKey(from_inner_blks_1);
    key_creator.AddAsKey(from_inner_idxs_1);
    key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
    key_creator.AddAsKey(from_dims);
    key_creator.AddAsKey(from_strides_outer_blocks);
    key_creator.AddAsKey(static_cast<int>(to_desc.extra.flags));
    key_creator.AddAsKey(static_cast<int>(to_inner_nblks));
    key_creator.AddAsKey(to_inner_blks_1);
    key_creator.AddAsKey(to_inner_idxs_1);
    key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
    key_creator.AddAsKey(to_dims);
    key_creator.AddAsKey(to_strides_outer_blocks);
    return key_creator.GetKey();
  }

 private:
  MklReorderPrimitiveFactory() {}
  ~MklReorderPrimitiveFactory() {}

  MklPrimitive* GetReorder(const memory* from, const memory* to) {
    string key = CreateKey(from, to);
    return this->GetOp(key);
  }

  void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
    string key = CreateKey(from, to);
    this->SetOp(key, op);
  }
};

/// Function to find (or create) a reorder from the memory pointed to by
/// `from` to the memory pointed to by `to`. It creates the primitive or
/// fetches it from the pool if it is already cached.
/// Returns the primitive.
template <typename T>
inline MklReorderPrimitive* FindOrCreateReorder(const memory* from,
                                                const memory* to) {
  CHECK_NOTNULL(from);
  CHECK_NOTNULL(to);
  MklReorderPrimitive* reorder_prim =
      MklReorderPrimitiveFactory<T>::Get(from, to);
  return reorder_prim;
}

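// Illustrative sketch (not part of this header): executing a cached reorder
// between two user memories, mirroring the pattern used in
// MklDnnData::CheckReorderToOpMem above. `src_mem` and `dst_mem` are assumed
// to be mkldnn::memory objects with attached data handles.
//
//   MklReorderPrimitive* prim = FindOrCreateReorder<float>(&src_mem, &dst_mem);
//   std::vector<primitive> net = {*(prim->GetPrimitive())};
//   std::vector<MemoryArgsMap> net_args = {
//       {{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}}};
//   std::shared_ptr<stream> cpu_stream(
//       CreateStream(nullptr, prim->GetEngine()));
//   execute_primitives(net, cpu_stream, net_args);
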
// Utility function to determine whether a convolution is 1x1 with stride != 1,
// for the purpose of temporarily disabling primitive reuse.
inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
                                memory::dims strides) {
  if (filter_dims.size() != 4 || strides.size() != 2) return false;

  return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
          ((strides[0] != 1) || (strides[1] != 1)));
}

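// Illustrative sketch (not part of this header): filter_dims follow the
// [Out_Channel, In_Channel, Height, Width] ordering described above, so
// indices 2 and 3 are the spatial dimensions.
//
//   IsConv1x1StrideNot1({32, 16, 1, 1}, {2, 2});  // true: 1x1 filter, stride 2
//   IsConv1x1StrideNot1({32, 16, 3, 3}, {2, 2});  // false: 3x3 filter
//   IsConv1x1StrideNot1({32, 16, 1, 1}, {1, 1});  // false: stride is 1
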
}  // namespace tensorflow

/////////////////////////////////////////////////////////////////////
// Macros for handling registration for various types
/////////////////////////////////////////////////////////////////////

#define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input);

#define REGISTER_TEST_BFLOAT16(TEST) \
  REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input);

#define REGISTER_TEST_ALL_TYPES(TEST) \
  REGISTER_TEST_FLOAT32(TEST);        \
  REGISTER_TEST_BFLOAT16(TEST);
#else
#define REGISTER_TEST_ALL_TYPES(TEST) REGISTER_TEST_FLOAT32(TEST);

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_