• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
17 #define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
18 #ifdef INTEL_MKL
19 
#include <cstring>
#include <list>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
26 
27 #include "mkldnn.hpp"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/graph/mkl_graph_util.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/stringpiece.h"
34 #include "tensorflow/core/lib/gtl/array_slice.h"
35 #include "tensorflow/core/platform/cpu_info.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/util/env_var.h"
39 #include "tensorflow/core/util/mkl_threadpool.h"
40 #include "tensorflow/core/util/mkl_types.h"
41 #include "tensorflow/core/util/padding.h"
42 #include "tensorflow/core/util/tensor_format.h"
43 
44 using mkldnn::engine;
45 using mkldnn::memory;
46 using mkldnn::primitive;
47 using mkldnn::reorder;
48 using mkldnn::stream;
49 using CPUDevice = Eigen::ThreadPoolDevice;
50 using MemoryArgsMap = std::unordered_map<int, memory>;
51 using ReorderPd = mkldnn::reorder::primitive_desc;
52 
53 #ifdef _WIN32
54 typedef unsigned int uint;
55 #endif
56 
57 namespace tensorflow {
58 
59 // The file contains a number of utility classes and functions used by MKL
60 // enabled kernels
61 
62 // This class encapsulates all the meta data that is associated with an MKL
63 // tensor. A tensor is an MKL tensor if it was created as the result of an
64 // MKL operation, and did not go through a conversion to a standard
65 // Tensorflow tensor.
66 
// Axis positions in the ordering MKL-DNN uses internally for 2D data:
//   activations: [Batch, Channel, Height, Width]
//   filters:     [Out_Channel, In_Channel, Height, Width]
// Dim_O and Dim_I deliberately share values with Dim_N and Dim_C, because a
// filter stores its output/input channels in the first two positions.
enum MklDnnDims {
  // Activation axes.
  Dim_N = 0,
  Dim_C = 1,
  Dim_H = 2,
  Dim_W = 3,
  // Filter axes that differ from the activation naming.
  Dim_O = 0,
  Dim_I = 1
};
78 
// Axis positions in the ordering MKL-DNN uses internally for 3D data:
//   activations: [Batch, Channel, Depth, Height, Width]
//   filters:     [Out_Channel, In_Channel, Depth, Height, Width]
// Dim3d_O and Dim3d_I deliberately share values with Dim3d_N and Dim3d_C.
enum MklDnnDims3D {
  // Activation axes.
  Dim3d_N = 0,
  Dim3d_C = 1,
  Dim3d_D = 2,
  Dim3d_H = 3,
  Dim3d_W = 4,
  // Filter axes that differ from the activation naming.
  Dim3d_O = 0,
  Dim3d_I = 1
};
91 
// Axis positions of a TensorFlow 2D filter, whose shape is
// [filter_height, filter_width, in_channels, out_channels].
enum TFFilterDims2d {
  TF_2DFILTER_DIM_H = 0,  // filter_height
  TF_2DFILTER_DIM_W = 1,  // filter_width
  TF_2DFILTER_DIM_I = 2,  // in_channels
  TF_2DFILTER_DIM_O = 3   // out_channels
};
100 
// Axis positions of a TensorFlow 3D filter, whose shape is
// [filter_depth, filter_height, filter_width, in_channels, out_channels].
enum TFFilterDims3d {
  TF_3DFILTER_DIM_P = 0,  // filter_depth
  TF_3DFILTER_DIM_H = 1,  // filter_height
  TF_3DFILTER_DIM_W = 2,  // filter_width
  TF_3DFILTER_DIM_I = 3,  // in_channels
  TF_3DFILTER_DIM_O = 4   // out_channels
};
110 
// Axis positions MKL-DNN requires for the filter of a grouped convolution
// (2D only): [Group, Out_Channel, In_Channel, Height, Width].
enum MklDnnFilterGroupDims {
  MKL_GROUP_FILTER_DIM_G = 0,
  MKL_GROUP_FILTER_DIM_O = 1,
  MKL_GROUP_FILTER_DIM_I = 2,
  MKL_GROUP_FILTER_DIM_H = 3,
  MKL_GROUP_FILTER_DIM_W = 4
};
120 
// Template selector for MklOp kernel implementations that exist in both an
// int8 and an fp32 flavour.
enum class MklQuantization {
  QUANTIZED_VERSION,  // int8 (quantized) implementation
  FP_VERSION,         // fp32 implementation
};
127 
// Batch-size constant shared by MKL kernels.
// NOTE(review): its users are outside this section; presumably a threshold
// below which a batch is treated as "small" -- confirm against callers.
static constexpr int kSmallBatchSize = 32;
129 
execute_primitives(std::vector<mkldnn::primitive> & primitives,std::shared_ptr<stream> stream,std::vector<std::unordered_map<int,memory>> & net_args)130 inline void execute_primitives(
131     std::vector<mkldnn::primitive>& primitives, std::shared_ptr<stream> stream,
132     std::vector<std::unordered_map<int, memory>>& net_args) {
133   DCHECK_EQ(primitives.size(), net_args.size());
134   for (size_t i = 0; i < primitives.size(); ++i) {
135     primitives.at(i).execute(*stream, net_args.at(i));
136   }
137 }
138 
139 // In MKL-DNN v1.x, the format (ex. NCHW) used to initialize a memory descriptor
140 // (md) structure will no longer be recorded in its `format` field. Instead, it
141 // will be set to a canonical `blocked` format for every fully described md.
142 //
143 // Currently, we query this `format` field while mapping MKL-DNN's data format
144 // to TF's data format. Due to the above restriction, we will now get this data
145 // format information from TF's `data_format` attribute (i.e. via
146 // `TensorFormat`) for MKL-DNN v1.x.
147 //
148 // Some MKL-DNN operators such as ReLU do not have a `data_format` attribute
149 // since they are usually in `blocked` format. Therefore, in order to
150 // distinguish between blocked and non-blocked formats, we have defined a new
151 // enum called `MklTensorFormat` that is semantically similar to `TensorFormat`
152 // but with the following additional fields namely:
153 //  1) FORMAT_BLOCKED: as described above, this is needed for element-wise
154 //     operators such as ReLU.
155 //  2) FORMAT_INVALID: for error-checking (ex. unsupported format)
156 //  3) FORMAT_X, FORMAT_NC, FORMAT_TNC: to distinguish between MKL tensors based
157 //     on their dimensions in operators such as Softmax, i.e.:
158 //        FORMAT_X   - 1D tensor
159 //        FORMAT_NC  - 2D tensor
160 //        FORMAT_TNC - 3D tensor
// TensorFlow-side data format of an MKL tensor. The numeric values are spelled
// out explicitly and must stay stable.
enum class MklTensorFormat {
  // Formats that mirror TensorFormat.
  FORMAT_NHWC = 0,
  FORMAT_NCHW = 1,
  FORMAT_NDHWC = 2,
  FORMAT_NCDHW = 3,
  // Rank-only formats used by operators such as Softmax.
  FORMAT_X = 4,    // 1D tensor
  FORMAT_NC = 5,   // 2D tensor
  FORMAT_TNC = 6,  // 3D tensor
  // Blocked layout; needed for element-wise operators such as ReLU that have
  // no `data_format` attribute.
  FORMAT_BLOCKED = 7,
  // Error marker (e.g. unsupported format).
  FORMAT_INVALID = 8,
};
172 
// Forward declarations of free functions that the inline methods of
// MklDnnShape (below) call. Only the declarations are visible at this point.

// Converts an MklTensorFormat into an MKL-DNN memory::format_tag. Used by
// MklDnnShape::GetTfLayout() for non-blocked formats.
memory::format_tag MklTensorFormatToMklDnnDataFormat(MklTensorFormat format);

// Convert an MklTensorFormat into a TensorFlow TensorFormat.
// MklDnnShape::SetTfDimOrder(size_t, MklTensorFormat) uses the second one.
// NOTE(review): the "3D" variant presumably handles the 5-D formats -- confirm
// against its definition.
TensorFormat MklDnn3DDataFormatToTFDataFormat(MklTensorFormat format);
TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format);

// Computes strides for dimensions given in TensorFlow order. Used by
// MklDnnShape::GetTfLayout() for the blocked format.
memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
// Fills `blocked_md` with a blocked memory descriptor built from `dim`,
// `strides` and `dtype`; reports failure through the returned Status.
Status CreateBlockedMemDescHelper(const memory::dims& dim,
                                  const memory::dims& strides,
                                  memory::data_type dtype,
                                  mkldnn_memory_desc_t* blocked_md);
184 
185 inline std::ostream& operator<<(std::ostream& os,
186                                 const memory::format_tag& tag) {
187   if (tag == memory::format_tag::undef) {
188     os << "undef";
189   } else if (tag == memory::format_tag::any) {
190     os << "any";
191   } else {
192     os << "invalid";
193   }
194   return os;
195 }
196 
197 inline void operator<<(std::ostream& os, const MklTensorFormat& format) {
198   if (format == MklTensorFormat::FORMAT_NHWC) {
199     os << "FORMAT_NHWC";
200   } else if (format == MklTensorFormat::FORMAT_NCHW) {
201     os << "FORMAT_NCHW";
202   } else if (format == MklTensorFormat::FORMAT_NDHWC) {
203     os << "FORMAT_NDHWC";
204   } else if (format == MklTensorFormat::FORMAT_NCDHW) {
205     os << "FORMAT_NCDHW";
206   } else if (format == MklTensorFormat::FORMAT_X) {
207     os << "FORMAT_X";
208   } else if (format == MklTensorFormat::FORMAT_NC) {
209     os << "FORMAT_NC";
210   } else if (format == MklTensorFormat::FORMAT_TNC) {
211     os << "FORMAT_TNC";
212   } else if (format == MklTensorFormat::FORMAT_BLOCKED) {
213     os << "FORMAT_BLOCKED";
214   } else {
215     os << "INVALID FORMAT";
216   }
217 }
218 
// Element-wise comparison of two arrays.
// Returns true when the first `size` elements of `a1` and `a2` compare equal
// (trivially true for size == 0), false at the first mismatch.
template <typename T>
inline bool array_cmp(const T* a1, const T* a2, size_t size) {
  const T* lhs = a1;
  const T* rhs = a2;
  for (size_t remaining = size; remaining > 0; --remaining, ++lhs, ++rhs) {
    if (*lhs != *rhs) {
      return false;
    }
  }
  return true;
}
225 
CreateStream(OpKernelContext * ctx,const engine & engine)226 inline mkldnn::stream* CreateStream(OpKernelContext* ctx,
227                                     const engine& engine) {
228 #ifdef ENABLE_MKLDNN_THREADPOOL
229   stream_attr tp_stream_attr(engine::kind::cpu);
230   if (ctx != nullptr) {
231     auto eigen_tp =
232         MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx);
233     tp_stream_attr.set_threadpool(eigen_tp);
234     stream* tp_stream =
235         new stream(engine, stream::flags::default_flags, tp_stream_attr);
236     return tp_stream;
237   } else {
238     stream* tp_stream = new stream(engine);
239     return tp_stream;
240   }
241 #else
242   stream* tp_stream = new stream(engine);
243   return tp_stream;
244 #endif  // ENABLE_MKLDNN_THREADPOOL
245 }
246 
247 class MklDnnShape {
248  private:
249   struct MklShapeData {
250     // Flag to indicate if the tensor is an MKL tensor or not
251     bool is_mkl_tensor_ = false;
252     // Number of dimensions in Tensorflow format
253     size_t dimension_ = 0;
254     mkldnn_dims_t sizes_;  // Required by MKL for conversions
255     MklTensorFormat tf_data_format_ = MklTensorFormat::FORMAT_BLOCKED;
256     memory::data_type T_ = memory::data_type::undef;
257     // MKL layout
258     mkldnn_memory_desc_t mkl_md_;
259     /// TF dimension corresponding to this MKL dimension
260     mkldnn_dims_t map_;
261   };
262   MklShapeData data_;
263 
264   typedef std::remove_extent<mkldnn_dims_t>::type mkldnn_dim_t;
265 
266 #define INVALID_DIM_SIZE -1
267 
268  public:
MklDnnShape()269   MklDnnShape() {
270     for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
271          ++i) {
272       data_.sizes_[i] = -1;
273     }
274     for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) {
275       data_.map_[i] = -1;
276     }
277   }
278 
~MklDnnShape()279   ~MklDnnShape() {}
280   TF_DISALLOW_COPY_AND_ASSIGN(MklDnnShape);  // Cannot copy
281 
282   /// Equality function for MklDnnShape objects
283   /// @return true if both are equal; false otherwise.
284   inline bool operator==(const MklDnnShape& input_shape) const {
285     if (this->IsMklTensor() != input_shape.IsMklTensor()) {
286       return false;
287     }
288 
289     // If input tensors are in MKL layout, then we check for dimensions and
290     // sizes.
291     if (this->IsMklTensor()) {
292       const mkldnn_memory_desc_t& cur_md = (this->GetMklLayout()).data;
293       const mkldnn_memory_desc_t& input_shape_md =
294           input_shape.GetMklLayout().data;
295       return this->GetTfShape() == input_shape.GetTfShape() &&
296              mkldnn_memory_desc_equal(&cur_md, &input_shape_md);
297     }
298 
299     // Both inputs are not MKL tensors.
300     return true;
301   }
302 
303   /// Equality operator for MklDnnShape and TFShape.
304   /// Returns: true if TF shapes for both are the same, false otherwise
305   inline bool operator==(const TensorShape& input_shape) const {
306     if (!this->IsMklTensor()) {
307       return false;
308     }
309 
310     return this->GetTfShape() == input_shape;
311   }
312 
IsMklTensor()313   inline const bool IsMklTensor() const { return data_.is_mkl_tensor_; }
SetMklTensor(bool is_mkl_tensor)314   inline void SetMklTensor(bool is_mkl_tensor) {
315     data_.is_mkl_tensor_ = is_mkl_tensor;
316   }
317 
SetDimensions(const size_t dimension)318   inline void SetDimensions(const size_t dimension) {
319     data_.dimension_ = dimension;
320   }
GetDimension(char dimension)321   inline size_t GetDimension(char dimension) const {
322     int index = GetMklDnnTensorDimIndex(dimension);
323     CHECK(index >= 0 && index < this->GetDimension())
324         << "Invalid index from the dimension: " << index << ", " << dimension;
325     return this->DimSize(index);
326   }
327 
GetDimension3D(char dimension)328   inline size_t GetDimension3D(char dimension) const {
329     int index = GetMklDnnTensor3DDimIndex(dimension);
330     CHECK(index >= 0 && index < this->GetDimension())
331         << "Invalid index from the dimension: " << index << ", " << dimension;
332     return this->DimSize(index);
333   }
334 
GetMklDnnTensorDimIndex(char dimension)335   inline int32 GetMklDnnTensorDimIndex(char dimension) const {
336     switch (dimension) {
337       case 'N':
338         return MklDnnDims::Dim_N;
339       case 'C':
340         return MklDnnDims::Dim_C;
341       case 'H':
342         return MklDnnDims::Dim_H;
343       case 'W':
344         return MklDnnDims::Dim_W;
345       default:
346         LOG(FATAL) << "Invalid dimension: " << dimension;
347         return -1;  // Avoid compiler warning about missing return value
348     }
349   }
350 
GetMklDnnTensor3DDimIndex(char dimension)351   inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
352     switch (dimension) {
353       case 'N':
354         return MklDnnDims3D::Dim3d_N;
355       case 'C':
356         return MklDnnDims3D::Dim3d_C;
357       case 'D':
358         return MklDnnDims3D::Dim3d_D;
359       case 'H':
360         return MklDnnDims3D::Dim3d_H;
361       case 'W':
362         return MklDnnDims3D::Dim3d_W;
363       default:
364         LOG(FATAL) << "Invalid dimension: " << dimension;
365         return -1;  // Avoid compiler warning about missing return value
366     }
367   }
368 
GetDimension()369   inline size_t GetDimension() const { return data_.dimension_; }
GetSizes()370   inline const int* GetSizes() const {
371     return reinterpret_cast<const int*>(&data_.sizes_[0]);
372   }
373 
374   // Returns an mkldnn::memory::dims object that contains the sizes of this
375   // MklDnnShape object.
GetSizesAsMklDnnDims()376   inline memory::dims GetSizesAsMklDnnDims() const {
377     memory::dims retVal;
378     if (data_.is_mkl_tensor_) {
379       size_t dimensions = sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
380       for (size_t i = 0; i < dimensions; i++) {
381         if (data_.sizes_[i] != INVALID_DIM_SIZE)
382           retVal.push_back(data_.sizes_[i]);
383       }
384     } else {
385       CHECK_EQ(data_.is_mkl_tensor_, true);
386     }
387     return retVal;
388   }
389 
DimSize(int index)390   inline int64 DimSize(int index) const {
391     CHECK_LT(index, sizeof(data_.sizes_) / sizeof(data_.sizes_[0]));
392     return data_.sizes_[index];
393   }
394 
395   /// Return TensorShape that describes the Tensorflow shape of the tensor
396   /// represented by this MklShape.
GetTfShape()397   inline TensorShape GetTfShape() const {
398     CHECK_EQ(data_.is_mkl_tensor_, true);
399 
400     std::vector<int32> shape(data_.dimension_, -1);
401     // As mentioned in the comment above, we now rely on TF's `data_format`
402     // attribute to determine if TF shape is in blocked format or not.
403     if (data_.tf_data_format_ != MklTensorFormat::FORMAT_BLOCKED) {
404       for (size_t idx = 0; idx < data_.dimension_; ++idx) {
405         shape[idx] = data_.sizes_[TfDimIdx(idx)];
406       }
407     } else {
408       // If Tensorflow shape is in Blocked format, then we don't have dimension
409       // map for it. So we just create Tensorflow shape from sizes in the
410       // specified order.
411       for (size_t idx = 0; idx < data_.dimension_; ++idx) {
412         shape[idx] = data_.sizes_[idx];
413       }
414     }
415 
416     TensorShape ts;
417     bool ret = TensorShapeUtils::MakeShape(shape, &ts).ok();
418     CHECK_EQ(ret, true);
419     return ts;
420   }
421 
SetElemType(memory::data_type dt)422   inline void SetElemType(memory::data_type dt) { data_.T_ = dt; }
GetElemType()423   inline const memory::data_type GetElemType() { return data_.T_; }
424 
SetMklLayout(memory::desc * md)425   inline void SetMklLayout(memory::desc* md) {
426     CHECK_NOTNULL(md);
427     data_.mkl_md_ = md->data;
428   }
429 
GetMklLayout()430   inline const memory::desc GetMklLayout() const {
431     return memory::desc(data_.mkl_md_);
432   }
433 
GetTfDataFormat()434   inline MklTensorFormat GetTfDataFormat() const {
435     return data_.tf_data_format_;
436   }
437 
438   /// We don't create primitive_descriptor for TensorFlow layout now.
439   /// We use lazy evaluation and create it only when needed. Input format can
440   /// also be Blocked format.
SetTfLayout(size_t dims,const memory::dims & sizes,MklTensorFormat format)441   inline void SetTfLayout(size_t dims, const memory::dims& sizes,
442                           MklTensorFormat format) {
443     DCHECK_EQ(dims, sizes.size())
444         << "SetTfLayout: Number of dimensions does not"
445            "match with dimension array";
446     data_.dimension_ = dims;
447     for (size_t ii = 0; ii < dims; ++ii) {
448       data_.sizes_[ii] = sizes[ii];
449     }
450     data_.tf_data_format_ = format;
451     if (format != MklTensorFormat::FORMAT_BLOCKED) {
452       if (dims == 2) {
453         data_.map_[0] = MklDnnDims::Dim_N;
454         data_.map_[1] = MklDnnDims::Dim_C;
455       } else {
456         SetTfDimOrder(dims, format);
457       }
458     }
459   }
460 
GetTfLayout()461   inline const memory::desc GetTfLayout() const {
462     memory::dims dims;
463     for (size_t ii = 0; ii < data_.dimension_; ++ii) {
464       dims.push_back(data_.sizes_[ii]);
465     }
466 
467     // Create Blocked memory desc if input TF format was set like that.
468     if (data_.tf_data_format_ == MklTensorFormat::FORMAT_BLOCKED) {
469       auto strides = CalculateTFStrides(dims);
470       mkldnn_memory_desc_t blocked_md;
471       TF_CHECK_OK(
472           CreateBlockedMemDescHelper(dims, strides, data_.T_, &blocked_md));
473       return memory::desc(blocked_md);
474     } else {
475       auto format_tag =
476           MklTensorFormatToMklDnnDataFormat(data_.tf_data_format_);
477       DCHECK_NE(format_tag, memory::format_tag::undef);
478       return memory::desc(dims, data_.T_, format_tag);
479     }
480   }
481 
GetCurLayout()482   inline const memory::desc GetCurLayout() const {
483     return IsMklTensor() ? GetMklLayout() : GetTfLayout();
484   }
485 
486   // We don't need a case of default dimension order because
487   // when an operator that does not get data_format attribute gets all inputs
488   // in Tensorflow format, it will produce output in Tensorflow format.
SetTfDimOrder(const size_t dimension,const mkldnn_dims_t map)489   inline void SetTfDimOrder(const size_t dimension, const mkldnn_dims_t map) {
490     CHECK(dimension == data_.dimension_);
491     for (size_t ii = 0; ii < dimension; ii++) {
492       data_.map_[ii] = map[ii];
493     }
494   }
495 
SetTfDimOrder(const size_t dimension,TensorFormat data_format)496   inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
497     if (dimension == 5) {
498       CHECK(dimension == data_.dimension_);
499       data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
500           MklDnnDims3D::Dim3d_D;
501       data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
502           MklDnnDims3D::Dim3d_H;
503       data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
504           MklDnnDims3D::Dim3d_W;
505       data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
506           MklDnnDims3D::Dim3d_C;
507       data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
508           MklDnnDims3D::Dim3d_N;
509     } else {
510       CHECK_EQ(dimension, 4);
511       CHECK(dimension == data_.dimension_);
512       data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
513       data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
514       data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
515       data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
516     }
517   }
518 
SetTfDimOrder(const size_t dimension,MklTensorFormat format)519   inline void SetTfDimOrder(const size_t dimension, MklTensorFormat format) {
520     TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
521     SetTfDimOrder(dimension, data_format);
522   }
523 
GetTfToMklDimMap()524   inline const mkldnn_dim_t* GetTfToMklDimMap() const { return &data_.map_[0]; }
TfDimIdx(int index)525   inline size_t TfDimIdx(int index) const { return data_.map_[index]; }
TfDimSize(int index)526   inline int64 TfDimSize(int index) const {
527     return data_.sizes_[TfDimIdx(index)];
528   }
529 
530   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
531   /// corresponds to MKL's Channel dimension.
IsMklChannelDim(int d)532   inline bool IsMklChannelDim(int d) const {
533     return TfDimIdx(d) == MklDnnDims::Dim_C;
534   }
535 
536   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
537   /// corresponds to MKL's Batch dimension.
IsMklBatchDim(int d)538   inline bool IsMklBatchDim(int d) const {
539     return TfDimIdx(d) == MklDnnDims::Dim_N;
540   }
541 
542   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
543   /// corresponds to MKL's Width dimension.
IsMklWidthDim(int d)544   inline bool IsMklWidthDim(int d) const {
545     return TfDimIdx(d) == MklDnnDims::Dim_W;
546   }
547   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
548   /// corresponds to MKL's Height dimension.
IsMklHeightDim(int d)549   inline bool IsMklHeightDim(int d) const {
550     return TfDimIdx(d) == MklDnnDims::Dim_H;
551   }
552 
553   /// Check if the TF-MKL dimension ordering map specifies if the input
554   /// tensor is in NCHW format.
IsTensorInNCHWFormat()555   inline bool IsTensorInNCHWFormat() const {
556     TensorFormat data_format = FORMAT_NCHW;
557     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
558             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
559             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
560             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
561   }
562 
563   /// Check if the TF-MKL dimension ordering map specifies if the input
564   /// tensor is in NHWC format.
IsTensorInNHWCFormat()565   inline bool IsTensorInNHWCFormat() const {
566     TensorFormat data_format = FORMAT_NHWC;
567     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
568             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
569             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
570             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
571   }
572 
573   /// The following methods are used for serializing and de-serializing the
574   /// contents of the mklshape object.
575   /// The data is serialized in this order
576   /// is_mkl_tensor_ : dimension_ : sizes_ : map_: format_ : T_ : mkl_pd_;
577 
578   /// Size of buffer to hold the serialized object, the size is computed by
579   /// following above mentioned order
GetSerializeBufferSize()580   inline size_t GetSerializeBufferSize() const { return sizeof(MklShapeData); }
581 
SerializeMklDnnShape(unsigned char * buf,size_t buf_size)582   void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const {
583     CHECK(buf_size >= GetSerializeBufferSize())
584         << "Buffer size is too small to SerializeMklDnnShape";
585     *reinterpret_cast<MklShapeData*>(buf) = data_;
586   }
587 
DeSerializeMklDnnShape(const unsigned char * buf,size_t buf_size)588   void DeSerializeMklDnnShape(const unsigned char* buf, size_t buf_size) {
589     // Make sure buffer holds at least is_mkl_tensor_.
590     CHECK(buf_size >= sizeof(data_.is_mkl_tensor_))
591         << "Buffer size is too small in DeSerializeMklDnnShape";
592 
593     const bool is_mkl_tensor = *reinterpret_cast<const bool*>(buf);
594     if (is_mkl_tensor) {  // If it is an MKL Tensor then read the rest
595       CHECK(buf_size >= GetSerializeBufferSize())
596           << "Buffer size is too small in DeSerializeMklDnnShape";
597       data_ = *reinterpret_cast<const MklShapeData*>(buf);
598     }
599   }
600 };
601 
602 // List of MklShape objects. Used in Concat/Split layers.
603 typedef std::vector<MklDnnShape> MklDnnShapeList;
604 
605 template <typename T>
606 class MklDnnData;
607 
608 // TODO merge with the execute_primitives.
609 inline void ExecutePrimitive(const std::vector<primitive>& net,
610                              const std::vector<MemoryArgsMap>* net_args,
611                              const engine& cpu_engine,
612                              OpKernelContext* context = nullptr) {
613   DCHECK(net_args);
614   DCHECK_EQ(net.size(), net_args->size());
615   stream* cpu_stream = CreateStream(context, cpu_engine);
616   for (size_t i = 0; i < net.size(); ++i) {
617     net.at(i).execute(*cpu_stream, net_args->at(i));
618   }
619   cpu_stream->wait();
620   delete cpu_stream;
621 }
622 template <typename T>
ConvertMklToTF(OpKernelContext * context,const Tensor & input_mkl_tensor,const MklDnnShape & input_mkl_shape,Tensor * output_tf_tensor)623 inline Status ConvertMklToTF(OpKernelContext* context,
624                              const Tensor& input_mkl_tensor,
625                              const MklDnnShape& input_mkl_shape,
626                              Tensor* output_tf_tensor) {
627   try {
628     if (!input_mkl_shape.IsMklTensor()) {
629       // Return input as is since it is already a TF tensor
630       *output_tf_tensor = input_mkl_tensor;
631       return Status::OK();
632     }
633 
634     // Allocate output tensor.
635     TensorShape output_tf_shape = input_mkl_shape.GetTfShape();
636     TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<T>::v(), output_tf_shape,
637                                        output_tf_tensor));
638 
639     engine cpu_engine(engine::kind::cpu, 0);
640     MklDnnData<T> input(&cpu_engine);
641 
642     // Get MKL layout of input tensor.
643     auto input_mkl_md = input_mkl_shape.GetMklLayout();
644     auto output_tf_md = input_mkl_shape.GetTfLayout();
645     input.SetUsrMem(input_mkl_md, &input_mkl_tensor);
646 
647     if (input.IsReorderNeeded(output_tf_md)) {
648       std::vector<primitive> net;
649       std::vector<MemoryArgsMap> net_args;
650       bool status = input.CheckReorderToOpMem(output_tf_md, output_tf_tensor,
651                                               net, net_args, cpu_engine);
652       if (!status) {
653         return Status(error::Code::INTERNAL,
654                       "ConvertMklToTF(): Failed to create reorder for input");
655       }
656       ExecutePrimitive(net, &net_args, cpu_engine, context);
657     } else {
658       // If not, just forward input tensor to output tensor.
659       bool status =
660           output_tf_tensor->CopyFrom(input_mkl_tensor, output_tf_shape);
661       if (!status) {
662         return Status(
663             error::Code::INTERNAL,
664             "ConvertMklToTF(): Failed to forward input tensor to output");
665       }
666     }
667     return Status::OK();
668   } catch (mkldnn::error& e) {
669     string error_msg = "Status: " + std::to_string(e.status) +
670                        ", message: " + string(e.message) + ", in file " +
671                        string(__FILE__) + ":" + std::to_string(__LINE__);
672     LOG(FATAL) << "Operation received an exception: " << error_msg;
673   }
674 }
675 
676 // Get the MKL shape from the second string tensor
GetMklShape(OpKernelContext * ctext,int n,MklDnnShape * mklshape,bool eager_mode)677 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape,
678                         bool eager_mode) {
679   if (!eager_mode) {
680     mklshape->DeSerializeMklDnnShape(
681         ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
682             .flat<uint8>()
683             .data(),
684         ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
685                 .flat<uint8>()
686                 .size() *
687             sizeof(uint8));
688   } else {
689     mklshape->SetMklTensor(false);
690   }
691 }
692 
GetMklShape(OpKernelContext * ctext,int n,MklDnnShape * mklshape)693 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
694   GetMklShape(ctext, n, mklshape, false);
695 }
696 
697 // Gets the actual input
MklGetInput(OpKernelContext * ctext,int n)698 inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
699   return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
700 }
701 
// Looks up the input list called `name` on `ctext` and stores it in
// `input_tensors`.
// Fatal if `input_tensors` is null or if the lookup does not succeed.
inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
                            OpInputList* input_tensors) {
  CHECK_NOTNULL(input_tensors);
  TF_CHECK_OK(ctext->input_list(name, input_tensors));
}
707 
GetMklShapeList(OpKernelContext * ctext,StringPiece name,MklDnnShapeList * mkl_shapes)708 inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
709                             MklDnnShapeList* mkl_shapes) {
710   OpInputList input_mkl_tensors;
711   GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
712 
713   for (int i = 0; i < input_mkl_tensors.size(); i++) {
714     (*mkl_shapes)[i].DeSerializeMklDnnShape(
715         input_mkl_tensors[i].flat<uint8>().data(),
716         input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
717   }
718 }
719 
720 /// Get shape of input tensor pointed by 'input_idx' in TensorShape format.
721 /// If the input tensor is in MKL layout, then obtains TensorShape from
722 /// MklShape.
723 inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx,
724                               bool eager_mode = false) {
725   // Sanity check.
726   CHECK_NOTNULL(context);
727   CHECK_LT(input_idx, context->num_inputs());
728 
729   MklDnnShape input_mkl_shape;
730   GetMklShape(context, input_idx, &input_mkl_shape, eager_mode);
731   if (input_mkl_shape.IsMklTensor() && !eager_mode) {
732     return input_mkl_shape.GetTfShape();
733   } else {
734     const Tensor& t = MklGetInput(context, input_idx);
735     return t.shape();
736   }
737 }
738 
739 // Allocate the second output tensor that will contain
740 // the MKL shape serialized
AllocateOutputSetMklShape(OpKernelContext * ctext,int n,const MklDnnShape & mkl_shape)741 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
742                                       const MklDnnShape& mkl_shape) {
743   Tensor* second_tensor = nullptr;
744   TensorShape second_shape;
745   second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
746   OP_REQUIRES_OK(ctext, ctext->allocate_output(
747                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
748                             second_shape, &second_tensor));
749   mkl_shape.SerializeMklDnnShape(
750       second_tensor->flat<uint8>().data(),
751       second_tensor->flat<uint8>().size() * sizeof(uint8));
752 }
753 
754 // Allocate the output tensor, create a second output tensor that will contain
755 // the MKL shape serialized
756 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
757                                       Tensor** output,
758                                       const TensorShape& tf_shape,
759                                       const MklDnnShape& mkl_shape,
760                                       bool eager_mode = false) {
761   OP_REQUIRES_OK(
762       ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
763                                     tf_shape, output));
764   if (!eager_mode) {
765     Tensor* second_tensor = nullptr;
766     TensorShape second_shape;
767     second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
768     OP_REQUIRES_OK(ctext, ctext->allocate_output(
769                               GetTensorMetaDataIndex(n, ctext->num_outputs()),
770                               second_shape, &second_tensor));
771     mkl_shape.SerializeMklDnnShape(
772         second_tensor->flat<uint8>().data(),
773         second_tensor->flat<uint8>().size() * sizeof(uint8));
774   }
775 }
776 
777 // Allocates a temp tensor and returns the data buffer for temporary storage.
778 template <typename T>
AllocTmpBuffer(OpKernelContext * context,Tensor * tensor_out,const memory::desc & pd,void ** buf_out)779 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
780                            const memory::desc& pd, void** buf_out) {
781   TensorShape tf_shape;
782 
783   tf_shape.AddDim(pd.get_size() / sizeof(T) + 1);
784   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
785                                                  tf_shape, tensor_out));
786   *buf_out = static_cast<void*>(tensor_out->flat<T>().data());
787 }
788 
789 template <typename T>
AllocTmpBuffer(OpKernelContext * context,Tensor * tensor_out,TensorShape tf_shape)790 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
791                            TensorShape tf_shape) {
792   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
793                                                  tf_shape, tensor_out));
794 }
795 
GetStridesFromSizes(MklTensorFormat data_format,size_t * strides,const size_t * sizes)796 inline void GetStridesFromSizes(MklTensorFormat data_format, size_t* strides,
797                                 const size_t* sizes) {
798   DCHECK_NE(data_format, MklTensorFormat::FORMAT_INVALID);
799   // MKL requires strides in NCHW
800   if (data_format == MklTensorFormat::FORMAT_NHWC) {
801     strides[0] = sizes[2];
802     strides[1] = sizes[0] * sizes[2];
803     strides[2] = 1;
804     strides[3] = sizes[0] * sizes[1] * sizes[2];
805   } else {
806     strides[0] = 1;
807     strides[1] = sizes[0];
808     strides[2] = sizes[0] * sizes[1];
809     strides[3] = sizes[0] * sizes[1] * sizes[2];
810   }
811 }
812 
CopyMklTensorInToOut(OpKernelContext * context,int idx_in,int idx_out)813 inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
814                                  int idx_out) {
815   int num_inputs = context->num_inputs();
816   int num_outputs = context->num_outputs();
817   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
818   int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
819   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
820   int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
821 
822   const Tensor& data = context->input(idx_data_in);
823   const Tensor& meta = context->input(idx_meta_in);
824   Tensor output(data.dtype());
825   Tensor meta_output(meta.dtype());
826 
827   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
828   CHECK(output.CopyFrom(data, data.shape()));
829   CHECK(meta_output.CopyFrom(meta, meta.shape()));
830   context->set_output(idx_data_out, output);
831   context->set_output(idx_meta_out, meta_output);
832 }
833 
CopyTfTensorInToOutWithShape(OpKernelContext * context,int idx_in,int idx_out,const TensorShape & shape)834 inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
835                                          int idx_out,
836                                          const TensorShape& shape) {
837   int num_inputs = context->num_inputs();
838   int num_outputs = context->num_outputs();
839   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
840   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
841 
842   const Tensor& data = context->input(idx_data_in);
843   MklDnnShape mkl_shape_output;
844   mkl_shape_output.SetMklTensor(false);
845   AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
846   Tensor output(data.dtype());
847   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
848   CHECK(output.CopyFrom(data, shape));
849   context->set_output(idx_data_out, output);
850 }
851 
ForwardTfTensorInToOut(OpKernelContext * context,int idx_in,int idx_out)852 inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
853                                    int idx_out) {
854   int num_inputs = context->num_inputs();
855   int num_outputs = context->num_outputs();
856   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
857   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
858 
859   MklDnnShape dnn_shape_output;
860   dnn_shape_output.SetMklTensor(false);
861   AllocateOutputSetMklShape(context, idx_out, dnn_shape_output);
862   if (IsRefType(context->input_dtype(idx_data_in))) {
863     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
864   } else {
865     context->set_output(idx_data_out, context->input(idx_data_in));
866   }
867 }
868 
ForwardMklTensorInToOut(OpKernelContext * context,int idx_in,int idx_out)869 inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in,
870                                     int idx_out) {
871   int num_inputs = context->num_inputs();
872   int num_outputs = context->num_outputs();
873   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
874   int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
875   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
876   int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
877 
878   if (IsRefType(context->input_dtype(idx_data_in))) {
879     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
880     context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
881   } else {
882     context->set_output(idx_data_out, context->input(idx_data_in));
883     context->set_output(idx_meta_out, context->input(idx_meta_in));
884   }
885 }
886 
887 // Set a dummy MKLDNN shape (called when the output is in TF format)
SetDummyMklDnnShapeOutput(OpKernelContext * context,uint32 idx_data_out)888 inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
889                                       uint32 idx_data_out) {
890   MklDnnShape mkl_shape_output;
891   mkl_shape_output.SetMklTensor(false);
892   AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
893 }
894 
895 // If the input tensor has ref count as 1, it is forwarded to the desired
896 // output port and the function returns true. In that case, it also allocates
897 // the serialized MklDnnShape object. Otherwise, the function returns false.
898 inline bool ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
899                                                 int idx_in, int idx_out,
900                                                 Tensor** output,
901                                                 const MklDnnShape& mkl_shape,
902                                                 bool always_forward = true) {
903   int num_inputs = context->num_inputs();
904   int num_outputs = context->num_outputs();
905   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
906   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
907   bool is_forwarded = false;
908   const Tensor& input_tensor = context->input(idx_data_in);
909   const auto output_shape = input_tensor.shape();
910   if (always_forward) {
911     if (IsRefType(context->input_dtype(idx_data_in))) {
912       context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
913     } else {
914       context->set_output(idx_data_out, input_tensor);
915     }
916   } else {
917     is_forwarded = context->forward_input_to_output_with_shape(
918         idx_data_in, idx_data_out, output_shape, output);
919   }
920   if (is_forwarded || always_forward) {
921     AllocateOutputSetMklShape(context, idx_out, mkl_shape);
922     return true;
923   }
924   return false;
925 }
926 
927 // Forward the MKL shape ONLY (used in elementwise and other ops where
928 // we call the eigen implementation and MKL shape is not used)
ForwardMklMetaDataInToOut(OpKernelContext * context,uint32 idx_data_in,uint32_t idx_data_out)929 inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
930                                       uint32 idx_data_in,
931                                       uint32_t idx_data_out) {
932   uint32 idx_meta_in =
933       GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
934   uint32 idx_meta_out =
935       GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
936 
937   if (IsRefType(context->input_dtype(idx_data_in))) {
938     context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
939   } else {
940     context->set_output(idx_meta_out, context->input(idx_meta_in));
941   }
942 }
943 
944 // -------------------------------------------------------------------
945 //          Common utility functions used by MKL unit tests
946 
GetMklMetaTensor()947 inline Tensor GetMklMetaTensor() {
948   MklDnnShape non_mkl_shape;
949   non_mkl_shape.SetMklTensor(false);
950 
951   auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize());
952   Tensor tensor(DT_UINT8, {size});
953 
954   non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
955                                      size * sizeof(uint8));
956   return tensor;
957 }
958 
959 // -------------------------------------------------------------------
960 
961 /// Return MKL-DNN data type (memory::data_type) for input type T
962 ///
963 /// @input None
964 /// @return memory::data_type corresponding to type T
965 template <typename T>
966 static memory::data_type MklDnnType();
967 
968 /// Instantiation for float type. Add similar instantiations for other
969 /// type if needed.
970 template <>
971 memory::data_type MklDnnType<float>() {
972   return memory::data_type::f32;
973 }
974 
975 template <>
976 memory::data_type MklDnnType<quint8>() {
977   return memory::data_type::u8;
978 }
979 
980 template <>
981 memory::data_type MklDnnType<uint8>() {
982   return memory::data_type::u8;
983 }
984 
985 template <>
986 memory::data_type MklDnnType<qint8>() {
987   return memory::data_type::s8;
988 }
989 
990 template <>
991 memory::data_type MklDnnType<qint32>() {
992   return memory::data_type::s32;
993 }
994 template <>
995 memory::data_type MklDnnType<bfloat16>() {
996 #ifdef ENABLE_INTEL_MKL_BFLOAT16
997   return memory::data_type::bf16;
998 #else
999   return memory::data_type::f32;
1000 #endif
1001 }
1002 
1003 // Map MklTensorFormat to MKL-DNN format tag
1004 //
1005 // @input: MklTensorFormat i.e. TensorFlow data format
1006 // @return: MKL-DNN's memory format tag corresponding to MklTensorFormat.
1007 //          Fails with an error if invalid data format.
MklTensorFormatToMklDnnDataFormat(MklTensorFormat format)1008 inline memory::format_tag MklTensorFormatToMklDnnDataFormat(
1009     MklTensorFormat format) {
1010   if (format == MklTensorFormat::FORMAT_NHWC) return memory::format_tag::nhwc;
1011   if (format == MklTensorFormat::FORMAT_NCHW) return memory::format_tag::nchw;
1012   if (format == MklTensorFormat::FORMAT_NDHWC) return memory::format_tag::ndhwc;
1013   if (format == MklTensorFormat::FORMAT_NCDHW) return memory::format_tag::ncdhw;
1014   if (format == MklTensorFormat::FORMAT_X) return memory::format_tag::x;
1015   if (format == MklTensorFormat::FORMAT_NC) return memory::format_tag::nc;
1016   if (format == MklTensorFormat::FORMAT_TNC) return memory::format_tag::tnc;
1017   return memory::format_tag::undef;
1018 }
1019 
1020 /// Map TensorFlow data format into MKL-DNN 3D data format
1021 /// @input: TensorFlow data format
1022 /// @return: MKL-DNN 3D data format corresponding to TensorFlow data format;
1023 ///          Fails with an error if invalid data format.
TFDataFormatToMklDnn3DDataFormat(TensorFormat format)1024 inline MklTensorFormat TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
1025   if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NDHWC;
1026   if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCDHW;
1027   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1028   return MklTensorFormat::FORMAT_INVALID;
1029 }
1030 
1031 /// Map TensorFlow data format into MKL-DNN data format
1032 ///
1033 /// @input: TensorFlow data format
1034 /// @return: MKL-DNN data format corresponding to TensorFlow data format;
1035 ///          Fails with an error if invalid data format.
TFDataFormatToMklDnnDataFormat(TensorFormat format)1036 inline MklTensorFormat TFDataFormatToMklDnnDataFormat(TensorFormat format) {
1037   if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NHWC;
1038   if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCHW;
1039   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1040   return MklTensorFormat::FORMAT_INVALID;
1041 }
1042 
1043 /// Map MKL-DNN data format into TensorFlow data format
1044 ///
1045 /// @input: MKL-DNN data format
1046 /// @return: Tensorflow data format corresponding to MKL-DNN data format;
1047 ///          Fails with an error if invalid data format.
MklDnnDataFormatToTFDataFormat(MklTensorFormat format)1048 inline TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format) {
1049   if (format == MklTensorFormat::FORMAT_NHWC ||
1050       format == MklTensorFormat::FORMAT_NDHWC)
1051     return FORMAT_NHWC;
1052   if (format == MklTensorFormat::FORMAT_NCHW ||
1053       format == MklTensorFormat::FORMAT_NCDHW)
1054     return FORMAT_NCHW;
1055   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1056 
1057   // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure
1058   // that we don't come here.
1059   return FORMAT_NHWC;
1060 }
1061 
1062 /// Map TensorShape object into memory::dims required by MKL-DNN
1063 ///
1064 /// This function will simply map input TensorShape into MKL-DNN dims
1065 /// naively. So it will preserve the order of dimensions. E.g., if
1066 /// input tensor is in NHWC format, then dims will be in NHWC format also.
1067 ///
1068 /// @input TensorShape object in shape
1069 /// @return memory::dims corresponding to TensorShape
TFShapeToMklDnnDims(const TensorShape & shape)1070 inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
1071   memory::dims dims(shape.dims());
1072   for (int d = 0; d < shape.dims(); ++d) {
1073     dims[d] = shape.dim_size(d);
1074   }
1075   return dims;
1076 }
1077 
1078 /// Map TensorShape object into memory::dims in NCHW format required by MKL-DNN
1079 ///
1080 /// This function is a specific one than above function. It will map input
1081 /// TensorShape into MKL-DNN dims in NCHW format. So it may not preserve the
1082 /// order of dimensions. E.g., if input tensor is in NHWC format, then dims
1083 /// will be in NCHW format, and not in NHWC format.
1084 ///
1085 /// @input TensorShape object in shape
1086 /// @return memory::dims in MKL-DNN required NCHW format
TFShapeToMklDnnDimsInNCHW(const TensorShape & shape,TensorFormat format)1087 inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
1088                                               TensorFormat format) {
1089   // Check validity of format.
1090   DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1091             MklTensorFormat::FORMAT_INVALID);
1092 
1093   int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
1094   int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
1095   int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
1096   int w = shape.dim_size(GetTensorDimIndex(format, 'W'));
1097 
1098   // MKL-DNN requires dimensions in NCHW format.
1099   return memory::dims({n, c, h, w});
1100 }
1101 
TFShapeToMklDnnDimsInNCDHW(const TensorShape & shape,TensorFormat format)1102 inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
1103                                                TensorFormat format) {
1104   // Validate format.
1105   DCHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
1106             MklTensorFormat::FORMAT_INVALID);
1107 
1108   int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
1109   int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
1110   int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
1111   int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
1112   int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
1113 
1114   // MKL-DNN requires dimensions in NCDHW format.
1115   return memory::dims({n, c, d, h, w});
1116 }
1117 
1118 /// Overloaded version of function TFShapeToMklDnnDimsInNCHW above.
1119 /// Input parameters are self-explanatory.
MklDnnDimsInNCHW(const memory::dims & in_dims,TensorFormat format)1120 inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
1121                                      TensorFormat format) {
1122   // Validate format.
1123   DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1124             MklTensorFormat::FORMAT_INVALID);
1125 
1126   int n = in_dims[GetTensorDimIndex(format, 'N')];
1127   int c = in_dims[GetTensorDimIndex(format, 'C')];
1128   int h = in_dims[GetTensorDimIndex(format, 'H')];
1129   int w = in_dims[GetTensorDimIndex(format, 'W')];
1130 
1131   // MKL-DNN requires dimensions in NCHW format.
1132   return memory::dims({n, c, h, w});
1133 }
1134 
1135 /// Overloaded version of function TFShapeToMklDnnDimsInNCDHW above.
1136 /// Input parameters are self-explanatory.
MklDnnDimsInNCDHW(const memory::dims & in_dims,TensorFormat format)1137 inline memory::dims MklDnnDimsInNCDHW(const memory::dims& in_dims,
1138                                       TensorFormat format) {
1139   // Validate format.
1140   DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1141             MklTensorFormat::FORMAT_INVALID);
1142 
1143   int n = in_dims[GetTensorDimIndex<3>(format, 'N')];
1144   int c = in_dims[GetTensorDimIndex<3>(format, 'C')];
1145   int d = in_dims[GetTensorDimIndex<3>(format, '0')];
1146   int h = in_dims[GetTensorDimIndex<3>(format, '1')];
1147   int w = in_dims[GetTensorDimIndex<3>(format, '2')];
1148 
1149   // MKL DNN requires dimensions in NCDHW format.
1150   return memory::dims({n, c, d, h, w});
1151 }
1152 
1153 /// Map MklDnn memory::dims object into TensorShape object.
1154 ///
1155 /// This function will simply map input shape in MKL-DNN memory::dims format
1156 /// in Tensorflow's TensorShape object by preserving dimension order.
1157 ///
1158 /// @input MKL-DNN memory::dims object
1159 /// @output TensorShape corresponding to memory::dims
MklDnnDimsToTFShape(const memory::dims & dims)1160 inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) {
1161   std::vector<int32> shape(dims.size(), -1);
1162   for (int d = 0; d < dims.size(); d++) {
1163     shape[d] = dims[d];
1164   }
1165 
1166   TensorShape ret;
1167   CHECK_EQ(TensorShapeUtils::MakeShape(shape, &ret).ok(), true);
1168   return ret;
1169 }
1170 
1171 /// Function to calculate strides given tensor shape in Tensorflow order
1172 /// E.g., if dims_tf_order is {1, 2, 3, 4}, then as per Tensorflow convention,
1173 /// dimension with size 1 is outermost dimension; while dimension with size 4 is
1174 /// innermost dimension. So strides for this tensor would be {4 * 3 * 2,
1175 /// 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}.
1176 ///
1177 /// @input Tensorflow shape in memory::dims type
1178 /// @return memory::dims containing strides for the tensor.
CalculateTFStrides(const memory::dims & dims_tf_order)1179 inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) {
1180   CHECK_GT(dims_tf_order.size(), 0);
1181   memory::dims strides(dims_tf_order.size());
1182   int last_dim_idx = dims_tf_order.size() - 1;
1183   strides[last_dim_idx] = 1;
1184   for (int d = last_dim_idx - 1; d >= 0; d--) {
1185     strides[d] = strides[d + 1] * dims_tf_order[d + 1];
1186   }
1187   return strides;
1188 }
1189 
1190 /// Helper function to create memory descriptor in Blocked format
1191 ///
1192 /// @input: Tensor dimensions
1193 /// @input: strides corresponding to dimensions. One can use utility
1194 ///         function such as CalculateTFStrides to compute strides
1195 ///         for given dimensions.
1196 /// @output: mkldnn_memory_desc_t object corresponding to blocked memory
1197 ///          format for given dimensions and strides.
1198 /// @return: Status indicating whether the blocked memory descriptor
1199 ///          was successfully created.
CreateBlockedMemDescHelper(const memory::dims & dim,const memory::dims & strides,memory::data_type dtype,mkldnn_memory_desc_t * blocked_md)1200 inline Status CreateBlockedMemDescHelper(const memory::dims& dim,
1201                                          const memory::dims& strides,
1202                                          memory::data_type dtype,
1203                                          mkldnn_memory_desc_t* blocked_md) {
1204   DCHECK_EQ(dim.size(), strides.size());
1205   const int kNumDims = dim.size();
1206   mkldnn_dim_t* input_dims = new mkldnn_dim_t[kNumDims];
1207   mkldnn_dim_t* input_strides = new mkldnn_dim_t[kNumDims];
1208   for (int i = 0; i < kNumDims; ++i) {
1209     input_dims[i] = dim[i];
1210     input_strides[i] = strides[i];
1211   }
1212   try {
1213     mkldnn_memory_desc_init_by_strides(blocked_md, kNumDims, input_dims,
1214                                        memory::convert_to_c(dtype),
1215                                        input_strides);
1216     delete[] input_dims;
1217     delete[] input_strides;
1218   } catch (mkldnn::error& e) {
1219     delete[] input_dims;
1220     delete[] input_strides;
1221     return Status(error::Code::INTERNAL,
1222                   tensorflow::strings::StrCat(
1223                       "Failed to create blocked memory descriptor.",
1224                       "Status: ", e.status, ", message: ", e.message));
1225   }
1226   return Status::OK();
1227 }
1228 
1229 inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
1230                                     const memory& src_mem,
1231                                     const memory& dst_mem, const engine& engine,
1232                                     OpKernelContext* ctx = nullptr) {
1233   std::vector<primitive> net;
1234   net.push_back(mkldnn::reorder(reorder_desc));
1235   std::vector<MemoryArgsMap> net_args;
1236   net_args.push_back({{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
1237   ExecutePrimitive(net, &net_args, engine, ctx);
1238 }
1239 
1240 class MklReorderPrimitive;
1241 
1242 template <typename T>
1243 inline MklReorderPrimitive* FindOrCreateReorder(const memory* from,
1244                                                 const memory* to);
1245 
1246 // Class to represent all the resources corresponding to a tensor in TensorFlow
1247 // that are required to execute an operation (such as Convolution).
1248 template <typename T>
1249 class MklDnnData {
1250  private:
1251   /// MKL-DNN memory primitive for input user memory
1252   memory* user_memory_;
1253 
1254   /// MKL-DNN memory primitive in case input or output reorder is needed.
1255   memory* reorder_memory_;
1256 
1257   /// Operations memory descriptor
1258   memory::desc* op_md_;
  // Flag to indicate if data is 3D or not.
1260   bool bIs3D;
1261   /// Operations temp buffer
1262   void* allocated_buffer_;
1263   /// CPU engine on which operation will be executed
1264   const engine* cpu_engine_;
1265 
1266  public:
MklDnnData(const engine * e)1267   explicit MklDnnData(const engine* e)
1268       : user_memory_(nullptr),
1269         reorder_memory_(nullptr),
1270         op_md_(nullptr),
1271         bIs3D(false),
1272         allocated_buffer_(nullptr),
1273         cpu_engine_(e) {}
1274 
~MklDnnData()1275   ~MklDnnData() {
1276     if (allocated_buffer_ != nullptr) {
1277       cpu_allocator()->DeallocateRaw(allocated_buffer_);
1278     }
1279     cpu_engine_ = nullptr;  // We don't own this.
1280     delete (user_memory_);
1281     delete (reorder_memory_);
1282     delete (op_md_);
1283   }
1284 
GetTensorBuffer(const Tensor * tensor)1285   inline void* GetTensorBuffer(const Tensor* tensor) const {
1286     CHECK_NOTNULL(tensor);
1287     return const_cast<void*>(
1288         static_cast<const void*>(tensor->flat<T>().data()));
1289   }
1290 
SetIs3DData(bool bIs3D_)1291   void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
GetIs3D()1292   bool GetIs3D() { return bIs3D; }
1293 
1294   /// Set user memory primitive using specified dimensions, memory format tag
1295   /// and data_buffer. Function automatically uses element data type by using
1296   /// input type T used for creating call object.
1297   ///
1298   /// In a nutshell, function allows user to describe the input tensor to
1299   /// an operation. E.g., filter of Conv2D is of shape {1, 2, 3, 4}, and
1300   /// memory format tag HWIO, and the buffer that contains actual values is
1301   /// pointed by data_buffer.
1302   inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm,
1303                         void* data_buffer = nullptr) {
1304     auto md = memory::desc(dim, MklDnnType<T>(), fm);
1305     SetUsrMem(md, data_buffer);
1306   }
1307 
SetUsrMem(const memory::dims & dim,memory::format_tag fm,const Tensor * tensor)1308   inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm,
1309                         const Tensor* tensor) {
1310     DCHECK(tensor);
1311     SetUsrMem(dim, fm, GetTensorBuffer(tensor));
1312   }
1313 
1314   /// Helper function to create memory descriptor in Blocked format
1315   ///
1316   /// @input: Tensor dimensions
1317   /// @input: strides corresponding to dimensions. One can use utility
1318   ///         function such as CalculateTFStrides to compute strides
1319   ///         for given dimensions.
1320   /// @return: memory::desc object corresponding to blocked memory format
1321   ///          for given dimensions and strides.
CreateBlockedMemDesc(const memory::dims & dim,const memory::dims & strides)1322   static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim,
1323                                                   const memory::dims& strides) {
1324     mkldnn_memory_desc_t blocked_md;
1325     TF_CHECK_OK(
1326         CreateBlockedMemDescHelper(dim, strides, MklDnnType<T>(), &blocked_md));
1327     return memory::desc(blocked_md);
1328   }
1329 
1330   /// A version of SetUsrMem call that allows user to create memory in blocked
1331   /// format. So in addition to accepting dimensions, it also accepts strides.
1332   /// This allows user to create memory for tensor in a format that is not
1333   /// supported by MKLDNN. E.g., MKLDNN does not support tensor format for 6
1334   /// dimensional tensor as a native format. But by using blocked format, a user
1335   /// can create memory for 6D tensor.
1336   inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
1337                         void* data_buffer = nullptr) {
1338     CHECK_EQ(dim.size(), strides.size());
1339     auto blocked_md = MklDnnData<T>::CreateBlockedMemDesc(dim, strides);
1340     SetUsrMem(blocked_md, data_buffer);
1341   }
1342 
SetUsrMem(const memory::dims & dim,const memory::dims & strides,const Tensor * tensor)1343   inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
1344                         const Tensor* tensor) {
1345     CHECK_NOTNULL(tensor);
1346     SetUsrMem(dim, strides, GetTensorBuffer(tensor));
1347   }
1348 
1349   /// A version of SetUsrMem with memory descriptor and tensor
SetUsrMem(const memory::desc & md,const Tensor * tensor)1350   inline void SetUsrMem(const memory::desc& md, const Tensor* tensor) {
1351     CHECK_NOTNULL(tensor);
1352     SetUsrMem(md, GetTensorBuffer(tensor));
1353   }
1354 
1355   /// A version of function to set user memory type that accepts memory
1356   /// descriptor directly, instead of accepting dimensions and format. This
1357   /// function is more generic than the one above, but the function above is
1358   /// sufficient in most cases.
1359   inline void SetUsrMem(const memory::desc& pd, void* data_buffer = nullptr) {
1360     DCHECK(cpu_engine_);
1361     if (user_memory_) delete user_memory_;
1362     // TODO(nhasabni): can we remove dynamic memory allocation?
1363     if (data_buffer) {
1364       user_memory_ = new memory(pd, *cpu_engine_, data_buffer);
1365     } else {
1366       user_memory_ = new memory(pd, *cpu_engine_);
1367     }
1368   }
1369 
1370   /// Get function for user memory primitive.
GetUsrMem()1371   inline const memory* GetUsrMem() const { return user_memory_; }
1372 
1373   /// Get function for descriptor of user memory.
GetUsrMemDesc()1374   inline memory::desc GetUsrMemDesc() const {
1375     DCHECK(user_memory_);
1376     return user_memory_->get_desc();
1377   }
1378 
1379   /// Get function for data buffer of user memory primitive.
GetUsrMemDataHandle()1380   inline void* GetUsrMemDataHandle() const {
1381     CHECK_NOTNULL(user_memory_);
1382     return user_memory_->get_data_handle();
1383   }
1384 
1385   /// Set function for data buffer of user memory primitive.
1386   inline void SetUsrMemDataHandle(void* data_buffer,
1387                                   std::shared_ptr<stream> t_stream = nullptr) {
1388     CHECK_NOTNULL(user_memory_);
1389     CHECK_NOTNULL(data_buffer);
1390 #ifdef ENABLE_MKLDNN_THREADPOOL
1391     user_memory_->set_data_handle(data_buffer, *t_stream);
1392 #else
1393     user_memory_->set_data_handle(data_buffer);
1394 #endif  // ENABLE_MKLDNN_THREADPOOL
1395   }
1396 
1397   /// Set function for data buffer of user memory primitive.
1398   inline void SetUsrMemDataHandle(const Tensor* tensor,
1399                                   std::shared_ptr<stream> t_stream = nullptr) {
1400     SetUsrMemDataHandle(GetTensorBuffer(tensor), t_stream);
1401   }
1402 
1403   /// allocate function for data buffer
AllocateBuffer(size_t size)1404   inline void AllocateBuffer(size_t size) {
1405     const int64 kMemoryAlignment = 64;  // For AVX512 memory alignment.
1406     allocated_buffer_ = cpu_allocator()->AllocateRaw(kMemoryAlignment, size);
1407   }
1408 
GetAllocatedBuffer()1409   inline void* GetAllocatedBuffer() { return allocated_buffer_; }
1410 
1411   /// Get the memory primitive for input and output of an op. If inputs
1412   /// to an op require reorders, then this function returns memory primitive
1413   /// for reorder. Otherwise, it will return memory primitive for user memory.
1414   ///
1415   /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
1416   /// execute Conv2D, we need memory primitive for I and F. But if reorder is
1417   /// required for I and F (say I_r is reorder primitive for I; F_r is reorder
1418   /// primitive for F), then we need I_r and F_r to perform Conv2D.
GetOpMem()1419   inline const memory& GetOpMem() const {
1420     return reorder_memory_ ? *reorder_memory_ : *user_memory_;
1421   }
1422 
1423   /// Set memory descriptor of an operation in terms of dimensions and memory
1424   /// format. E.g., For Conv2D, the dimensions would be same as user dimensions
1425   /// but memory::format_tag would be mkldnn::any because we want MKL-DNN to
1426   /// choose the best layout/format for given input dimensions.
SetOpMemDesc(const memory::dims & dim,memory::format_tag fm)1427   inline void SetOpMemDesc(const memory::dims& dim, memory::format_tag fm) {
1428     // TODO(nhasabni): can we remove dynamic memory allocation?
1429     op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
1430   }
1431 
1432   /// Get function for memory descriptor for an operation
GetOpMemDesc()1433   inline const memory::desc& GetOpMemDesc() const { return *op_md_; }
1434 
1435   /// Predicate that checks if we need to reorder user's memory into memory
1436   /// pointed by op_md.
1437   ///
1438   /// @input: op_md - memory descriptor of the given input of an operation.
1439   /// @return: true in case reorder of input is needed; false, otherwise.
IsReorderNeeded(const memory::desc & op_pd)1440   inline bool IsReorderNeeded(const memory::desc& op_pd) const {
1441     DCHECK(user_memory_);
1442     return op_pd != user_memory_->get_desc();
1443   }
1444 
1445   /// Function to create a reorder from memory pointed by from to memory pointed
1446   /// by to. Returns created primitive.
CreateReorder(const memory * from,const memory * to)1447   inline primitive CreateReorder(const memory* from, const memory* to) const {
1448     CHECK_NOTNULL(from);
1449     CHECK_NOTNULL(to);
1450     return reorder(*from, *to);
1451   }
1452 
1453   /// Function to handle input reordering
1454   ///
1455   /// Check if we need to reorder this input of an operation.
1456   /// Return true and allocate reorder memory primitive if reorder is needed.
1457   /// Otherwise, return false and do not allocate reorder memory primitive.
1458   ///
1459   /// To check if reorder is needed, this function compares memory primitive
1460   /// descriptor (memory descriptor for v1.x) of an operation (op_pd) for
1461   /// the given input with the user-specified memory descriptor.
1462   ///
1463   /// @input: op_pd - memory primitive descriptor of the given input of an
1464   ///                 operation
1465   /// @input: net - net to which to add reorder primitive in case it is needed.
1466   /// @input: net_args - net to which user and reorder memories are added if
1467   ///                    needed. Each entry is a key-value pair of the form
1468   ///                    <argument-type, mkldnn::memory>.
1469   /// @return: true in case reorder of input is needed; false, otherwise.
CheckReorderToOpMem(const memory::desc & op_md,std::vector<primitive> & net,std::vector<MemoryArgsMap> & net_args,const engine & engine)1470   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1471                                   std::vector<primitive>& net,
1472                                   std::vector<MemoryArgsMap>& net_args,
1473                                   const engine& engine) {
1474     DCHECK(user_memory_);
1475     DCHECK_EQ(net.size(), net_args.size());
1476     if (IsReorderNeeded(op_md)) {
1477       // TODO(nhasabni): can we remove dynamic memory allocation?
1478       reorder_memory_ = new memory(op_md, engine);
1479       net.push_back(CreateReorder(user_memory_, reorder_memory_));
1480       net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *user_memory_},
1481                                        {MKLDNN_ARG_TO, *reorder_memory_}});
1482       return true;
1483     }
1484     return false;
1485   }
1486 
1487   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1488                                   const engine& engine,
1489                                   OpKernelContext* context = nullptr) {
1490     DCHECK(user_memory_);
1491     if (IsReorderNeeded(op_md)) {
1492       // TODO(nhasabni): can we remove dynamic memory allocation?
1493       // primitive reuse don't allow two same reorder prim in
1494       // one stream, so submit it immediately
1495       reorder_memory_ = new memory(op_md, engine);
1496       auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
1497       std::shared_ptr<stream> cpu_stream;
1498       cpu_stream.reset(CreateStream(context, prim->GetEngine()));
1499       std::vector<primitive> net;
1500       net.push_back(*(prim->GetPrimitive()));
1501       std::vector<MemoryArgsMap> net_args;
1502       net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
1503                           {MKLDNN_ARG_TO, *reorder_memory_}});
1504       execute_primitives(net, cpu_stream, net_args);
1505       return true;
1506     }
1507     return false;
1508   }
1509 
1510   /// Overloaded version of above function that accepts memory buffer
1511   /// where output of reorder needs to be stored.
1512   ///
1513   /// @input: op_pd - memory primitive descriptor (memory descriptor for v1.x)
1514   ///                 of the given input of an operation
1515   /// @reorder_data_handle - memory buffer where output of reorder needs to be
1516   ///                        stored. Primitive does not check if buffer has
1517   ///                        enough size to write.
1518   /// @input: net - net to which to add reorder primitive in case it is needed.
1519   /// @input: net_args - net to which user and reorder memories are added if
1520   ///                    needed. Each entry is a key-value pair of the form
1521   ///                    <argument-type, mkldnn::memory>.
1522   /// @input: engine - MKL-DNN's abstraction of a computational device
1523   /// @return: true in case reorder of input is needed; false, otherwise.
CheckReorderToOpMem(const memory::desc & op_md,void * reorder_data_handle,std::vector<primitive> & net,std::vector<MemoryArgsMap> & net_args,const engine & engine)1524   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1525                                   void* reorder_data_handle,
1526                                   std::vector<primitive>& net,
1527                                   std::vector<MemoryArgsMap>& net_args,
1528                                   const engine& engine) {
1529     DCHECK(reorder_data_handle);
1530     DCHECK(user_memory_);
1531     if (IsReorderNeeded(op_md)) {
1532       // TODO(nhasabni): can we remove dynamic memory allocation?
1533       reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
1534       net.push_back(CreateReorder(user_memory_, reorder_memory_));
1535       net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *user_memory_},
1536                                        {MKLDNN_ARG_TO, *reorder_memory_}});
1537       return true;
1538     }
1539     return false;
1540   }
1541 
1542   /// This is a faster path with reorder primitive cache compared with
1543   /// CheckReorderToOpMem(..., std::vector<primitive>* net).
1544   /// The slower path will be removed in the future
1545   /// TODO(bhavanis): Need to use reorder cache here for better performance.
1546   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1547                                   void* reorder_data_handle,
1548                                   const engine& engine,
1549                                   OpKernelContext* context = nullptr) {
1550     DCHECK(reorder_data_handle);
1551     DCHECK(user_memory_);
1552     if (IsReorderNeeded(op_md)) {
1553       // TODO(nhasabni): can we remove dynamic memory allocation?
1554       // primitive reuse don't allow two same reorder prim in
1555       // one stream, so submit it immediately
1556       reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
1557       auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
1558       std::shared_ptr<stream> cpu_stream;
1559       cpu_stream.reset(CreateStream(context, prim->GetEngine()));
1560       std::vector<primitive> net;
1561       net.push_back(*(prim->GetPrimitive()));
1562       std::vector<MemoryArgsMap> net_args;
1563       net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
1564                           {MKLDNN_ARG_TO, *reorder_memory_}});
1565       execute_primitives(net, cpu_stream, net_args);
1566       return true;
1567     }
1568     return false;
1569   }
1570 
1571   /// Another overloaded version of CheckReorderToOpMem that accepts Tensor
1572   /// where output of reorder needs to be stored.
1573   ///
1574   /// @input: op_md - memory primitive descriptor (memory descriptor for v1.x)
1575   ///                 of the given input of an operation
1576   /// @reorder_tensor - Tensor whose buffer is to be used to store output of
1577   ///                   reorder. Primitive does not check if buffer is
1578   ///                   enough size to write.
1579   /// @input: net - net to which to add reorder primitive in case it is needed.
1580   /// @input: net_args - net to which user and reorder memories are added if
1581   ///                    needed. Each entry is a key-value pair of the form
1582   ///                    <argument-type, mkldnn::memory>.
1583   /// @input: engine - MKL-DNN's abstraction of a computational device
1584   /// @return: true in case reorder of input is needed; false, otherwise.
CheckReorderToOpMem(const memory::desc & op_md,Tensor * reorder_tensor,std::vector<primitive> & net,std::vector<MemoryArgsMap> & net_args,const engine & engine)1585   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1586                                   Tensor* reorder_tensor,
1587                                   std::vector<primitive>& net,
1588                                   std::vector<MemoryArgsMap>& net_args,
1589                                   const engine& engine) {
1590     DCHECK(reorder_tensor);
1591     return CheckReorderToOpMem(op_md, GetTensorBuffer(reorder_tensor), net,
1592                                net_args, engine);
1593   }
1594 
1595   /// TODO: this is a faster path with reorder primitive cache compared with
1596   /// CheckReorderToOpMem(op_md, reorder_tensor, net, net_args, engine), will
1597   /// remove
1598   /// slow path in the future
1599   inline bool CheckReorderToOpMem(const memory::desc& op_pd,
1600                                   Tensor* reorder_tensor,
1601                                   OpKernelContext* ctx = nullptr) {
1602     DCHECK(reorder_tensor);
1603     return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor),
1604                                *cpu_engine_, ctx);
1605   }
1606 
1607   /// Function to handle output reorder
1608   ///
1609   /// This function performs very similar functionality as input reordering
1610   /// function above. The only difference is that this function does not add
1611   /// reorder primitive to the net. The reason for this is: the reorder
1612   /// primitive for output needs to be added to the list only after operation
1613   /// has executed. But we need to prepare a temporary buffer in case output
1614   /// reorder is needed. And this temporary buffer will hold the output of
1615   /// an operation before it is fed to reorder primitive.
1616   ///
1617   /// @input - memory primitive descriptor (memory descriptor for v1.x) for the
1618   ///          given output of an operation
1619   /// @return: true in case reorder of output is needed; false, otherwise.
PrepareReorderToUserMemIfReq(const memory::desc & op_pd)1620   inline bool PrepareReorderToUserMemIfReq(const memory::desc& op_pd) {
1621     DCHECK(user_memory_);
1622     if (IsReorderNeeded(op_pd)) {
1623       // TODO(nhasabni): can we remove dynamic memory allocation?
1624       reorder_memory_ = new memory(op_pd, *cpu_engine_);
1625       return true;
1626     }
1627     return false;
1628   }
1629 
1630   /// Function to actually insert reorder primitive in the net
1631   ///
1632   /// This function completes remaining part of output reordering. It inserts
1633   /// a reordering primitive from the temporary buffer that holds the output
1634   /// to the user-specified output buffer.
1635   ///
1636   /// @input: net - net to which to add reorder primitive
1637   /// @input: net_args - net to which user and reorder memories are added if
1638   ///                    needed. Each entry is a key-value pair of the form
1639   ///                    <argument-type, mkldnn::memory>.
InsertReorderToUserMem(std::vector<primitive> & net,std::vector<MemoryArgsMap> & net_args)1640   inline void InsertReorderToUserMem(std::vector<primitive>& net,
1641                                      std::vector<MemoryArgsMap>& net_args) {
1642     DCHECK(user_memory_);
1643     DCHECK(reorder_memory_);
1644     net.push_back(CreateReorder(reorder_memory_, user_memory_));
1645     net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *reorder_memory_},
1646                                      {MKLDNN_ARG_TO, *user_memory_}});
1647   }
1648 
1649   /// TODO: this is a faster path with reorder primitive cache compared with
1650   ///       InsertReorderToUserMem(net, net_args), will remove
1651   ///       slow path in the future
1652   inline void InsertReorderToUserMem(OpKernelContext* ctx = nullptr) {
1653     DCHECK(user_memory_);
1654     DCHECK(reorder_memory_);
1655     DCHECK(cpu_engine_);
1656     // primitive reuse don't allow two same reorder prim in
1657     // one stream, so submit it immediately
1658     std::vector<primitive> net;
1659     auto* prim = FindOrCreateReorder<T>(reorder_memory_, user_memory_);
1660     net.push_back(*(prim->GetPrimitive()));
1661     std::vector<MemoryArgsMap> net_args;
1662     net_args.push_back(
1663         {{MKLDNN_ARG_FROM, *reorder_memory_}, {MKLDNN_ARG_TO, *user_memory_}});
1664     std::shared_ptr<stream> cpu_stream;
1665     cpu_stream.reset(CreateStream(ctx, prim->GetEngine()));
1666     execute_primitives(net, cpu_stream, net_args);
1667   }
1668 };
1669 
1670 /// Base class for operations with reuse of primitives
1671 class MklPrimitive {
1672  public:
~MklPrimitive()1673   virtual ~MklPrimitive() {}
MklPrimitive()1674   MklPrimitive() {}
MklPrimitive(const engine & cpu_engine)1675   MklPrimitive(const engine& cpu_engine) { cpu_engine_ = cpu_engine; }
1676   // Dummy data which MKL DNN never operates on
1677   unsigned char* DummyData = nullptr;
1678   engine cpu_engine_ = engine(engine::kind::cpu, 0);
GetEngine()1679   const engine& GetEngine() { return cpu_engine_; }
1680 };
1681 
// An empty dims list.
// NOTE(review): presumably a placeholder for "no dimensions" -- confirm
// against the callers, which are outside this part of the file.
const mkldnn::memory::dims NONE_DIMS = {};
1683 
//
// LRUCache is a class which implements LRU (Least Recently Used) cache.
// The implementation is similar to that of
//    tensorflow/core/platform/cloud/expiring_lru_cache.h
// without its thread-safe part because the cache is supposed to be
// used as thread local (for instance, MklPrimitive caching).
//
// The LRU list keeps keys ordered by most recent access (insertion counts
// as an access): the most recently accessed key is at the head of the list
// and the least recently accessed key is at the tail.
//
// This class is used to maintain an upper bound on the total number of
// cached items. When the cache reaches its capacity, the LRU item will
// be removed and replaced by a new one from SetOp call.
//
// Ownership: the cache owns every T* handed to SetOp and deletes it when the
// entry is evicted, replaced, cleared, or when the cache is destroyed.
//
template <typename T>
class LRUCache {
 public:
  explicit LRUCache(size_t capacity) : capacity_(capacity) {}

  // Returns the object cached under `key` and marks it as the most recently
  // accessed one, or returns nullptr if `key` is not cached.
  T* GetOp(const string& key) {
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      return nullptr;
    }
    MoveToFront(it->second);
    return it->second.op;
  }

  // Caches `op` under `key` as the most recently accessed entry and takes
  // ownership of it.
  //
  // If `key` is already cached, the stored object is replaced (the previous
  // object is deleted unless it is `op` itself) and no other entry is
  // evicted. Otherwise, if the cache is full, the least recently accessed
  // entry is evicted first.
  void SetOp(const string& key, T* op) {
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      // Re-use the existing slot. Inserting a second list node for the same
      // key would leave a stale LRU entry behind, and the failed map insert
      // would delete `op` while the caller still holds it.
      Entry& existing = it->second;
      if (existing.op != op) {
        delete existing.op;
        existing.op = op;
      }
      MoveToFront(existing);
      return;
    }

    if (lru_list_.size() >= capacity_) {
      Delete();
    }

    // Insert an entry to the front of the LRU list
    lru_list_.push_front(key);
    cache_.emplace(key, Entry(op, lru_list_.begin()));
  }

  // Removes (and deletes) every cached object.
  void Clear() {
    cache_.clear();
    lru_list_.clear();
  }

 private:
  struct Entry {
    // The entry's value.
    T* op;

    // A list iterator pointing to the entry's position in the LRU list.
    std::list<string>::iterator lru_iterator;

    // Constructor
    Entry(T* op, std::list<string>::iterator it) : op(op), lru_iterator(it) {}

    // Move constructor: steals the pointer so that only one Entry deletes it.
    Entry(Entry&& source) noexcept
        : op(source.op), lru_iterator(std::move(source.lru_iterator)) {
      source.op = nullptr;
    }

    // An Entry is the unique owner of `op`, so it must never be copied.
    Entry(const Entry&) = delete;
    Entry& operator=(const Entry&) = delete;

    // Destructor (deleting a null pointer is a no-op).
    ~Entry() { delete op; }
  };

  // Marks `entry` as the most recently accessed one. splice() relinks the
  // existing list node, so no key is copied and lru_iterator stays valid.
  void MoveToFront(Entry& entry) {
    lru_list_.splice(lru_list_.begin(), lru_list_, entry.lru_iterator);
  }

  // Remove the least recently accessed entry from LRU list, which
  // is the tail of lru_list_. Update cache_ correspondingly.
  bool Delete() {
    if (lru_list_.empty()) return false;
    string key = lru_list_.back();
    lru_list_.pop_back();
    cache_.erase(key);
    return true;
  }

  // Cache capacity
  size_t capacity_;

  // The cache, a map from string key to a LRU entry.
  std::unordered_map<string, Entry> cache_;

  // The LRU list of entries.
  // The front of the list contains the key of the most recently accessed
  // entry, while the back of the list is the least recently accessed entry.
  std::list<string> lru_list_;
};
1788 
1789 template <typename T>
1790 class MklPrimitiveFactory {
1791  public:
MklPrimitiveFactory()1792   MklPrimitiveFactory() {}
1793 
~MklPrimitiveFactory()1794   ~MklPrimitiveFactory() {}
1795 
GetOp(const string & key)1796   MklPrimitive* GetOp(const string& key) {
1797     auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1798     return lru_cache.GetOp(key);
1799   }
1800 
SetOp(const string & key,MklPrimitive * op)1801   void SetOp(const string& key, MklPrimitive* op) {
1802     auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1803     lru_cache.SetOp(key, op);
1804   }
1805 
1806   /// Function to decide whether HW has AVX512 or AVX2
1807   /// For those legacy device(w/o AVX512 and AVX2),
1808   /// MKL-DNN GEMM will be used.
IsLegacyPlatform()1809   static inline bool IsLegacyPlatform() {
1810     return (!port::TestCPUFeature(port::CPUFeature::AVX512F) &&
1811             !port::TestCPUFeature(port::CPUFeature::AVX2));
1812   }
1813 
1814   /// Function to check whether primitive memory optimization is enabled
IsPrimitiveMemOptEnabled()1815   static inline bool IsPrimitiveMemOptEnabled() {
1816     bool is_primitive_mem_opt_enabled = true;
1817     TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
1818                                    &is_primitive_mem_opt_enabled));
1819     return is_primitive_mem_opt_enabled;
1820   }
1821 
1822  private:
GetLRUCache()1823   static inline LRUCache<MklPrimitive>& GetLRUCache() {
1824     static const int kCapacity = 1024;  // cache capacity
1825     static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity);
1826     return lru_cache_;
1827   }
1828 };
1829 
1830 // utility class for creating keys of MKL primitive pool.
1831 class FactoryKeyCreator {
1832  public:
FactoryKeyCreator()1833   FactoryKeyCreator() { key_.reserve(kMaxKeyLength); }
1834 
~FactoryKeyCreator()1835   ~FactoryKeyCreator() {}
1836 
AddAsKey(const string & str)1837   void AddAsKey(const string& str) { Append(str); }
1838 
AddAsKey(const mkldnn::memory::dims & dims)1839   void AddAsKey(const mkldnn::memory::dims& dims) {
1840     for (unsigned int i = 0; i < dims.size(); i++) {
1841       AddAsKey<int>(dims[i]);
1842     }
1843   }
1844 
1845   template <typename T>
AddAsKey(const T data)1846   void AddAsKey(const T data) {
1847     auto buffer = reinterpret_cast<const char*>(&data);
1848     Append(StringPiece(buffer, sizeof(T)));
1849   }
1850 
GetKey()1851   string GetKey() { return key_; }
1852 
1853  private:
1854   string key_;
1855   const char delimiter = 'x';
1856   const int kMaxKeyLength = 256;
Append(StringPiece s)1857   void Append(StringPiece s) {
1858     key_.append(string(s));
1859     key_.append(1, delimiter);
1860   }
1861 };
1862 
1863 class MklReorderPrimitive : public MklPrimitive {
1864  public:
MklReorderPrimitive(const memory * from,const memory * to)1865   explicit MklReorderPrimitive(const memory* from, const memory* to)
1866       : MklPrimitive(engine(engine::kind::cpu, 0)) {
1867     Setup(from, to);
1868   }
~MklReorderPrimitive()1869   ~MklReorderPrimitive() {}
1870 
GetPrimitive()1871   std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
1872 
SetMemory(const memory * from,const memory * to)1873   void SetMemory(const memory* from, const memory* to) {
1874     context_.src_mem->set_data_handle(from->get_data_handle());
1875     context_.dst_mem->set_data_handle(to->get_data_handle());
1876   }
1877 
GetStream()1878   std::shared_ptr<mkldnn::stream> GetStream() { return stream_; }
1879 
1880  private:
1881   struct ReorderContext {
1882     std::shared_ptr<mkldnn::memory> src_mem;
1883     std::shared_ptr<mkldnn::memory> dst_mem;
1884     std::shared_ptr<primitive> reorder_prim;
ReorderContextReorderContext1885     ReorderContext()
1886         : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
1887   } context_;
1888 
1889   std::shared_ptr<mkldnn::stream> stream_;
1890 
Setup(const memory * from,const memory * to)1891   void Setup(const memory* from, const memory* to) {
1892     context_.src_mem.reset(
1893         new memory(from->get_desc(), cpu_engine_, DummyData));
1894     context_.dst_mem.reset(new memory(to->get_desc(), cpu_engine_, DummyData));
1895     context_.reorder_prim = std::make_shared<mkldnn::reorder>(
1896         reorder(*context_.src_mem, *context_.dst_mem));
1897     stream_.reset(new stream(cpu_engine_));
1898   }
1899 };
1900 
1901 template <typename T>
1902 class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
1903  public:
Get(const memory * from,const memory * to)1904   static MklReorderPrimitive* Get(const memory* from, const memory* to) {
1905     auto reorderPrim = static_cast<MklReorderPrimitive*>(
1906         MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
1907     if (reorderPrim == nullptr) {
1908       reorderPrim = new MklReorderPrimitive(from, to);
1909       MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to,
1910                                                               reorderPrim);
1911     }
1912     reorderPrim->SetMemory(from, to);
1913     return reorderPrim;
1914   }
1915 
GetInstance()1916   static MklReorderPrimitiveFactory& GetInstance() {
1917     static MklReorderPrimitiveFactory instance_;
1918     return instance_;
1919   }
1920 
CreateKey(const memory * from,const memory * to)1921   static string CreateKey(const memory* from, const memory* to) {
1922     string prefix = "reorder";
1923     FactoryKeyCreator key_creator;
1924     auto const& from_desc = from->get_desc().data;
1925     auto const& to_desc = to->get_desc().data;
1926     memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
1927     memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
1928     auto from_strides = from_desc.format_desc.blocking.strides;
1929 
1930     // As DNNL memory desc has C style array and only init the used
1931     // part, so need use the valid part as key.
1932     auto from_inner_nblks = from_desc.format_desc.blocking.inner_nblks;
1933     auto from_inner_blks = from_desc.format_desc.blocking.inner_blks;
1934     auto from_inner_idxs = from_desc.format_desc.blocking.inner_idxs;
1935     memory::dims from_inner_blks_1(from_inner_blks,
1936                                    &from_inner_blks[from_inner_nblks]);
1937     memory::dims from_inner_idxs_1(from_inner_idxs,
1938                                    &from_inner_idxs[from_inner_nblks]);
1939     auto to_inner_nblks = to_desc.format_desc.blocking.inner_nblks;
1940     auto to_inner_blks = to_desc.format_desc.blocking.inner_blks;
1941     auto to_inner_idxs = to_desc.format_desc.blocking.inner_idxs;
1942     memory::dims to_inner_blks_1(to_inner_blks, &to_inner_blks[to_inner_nblks]);
1943     memory::dims to_inner_idxs_1(to_inner_idxs, &to_inner_idxs[to_inner_nblks]);
1944 
1945     auto to_strides = to_desc.format_desc.blocking.strides;
1946     memory::dims from_strides_outer_blocks(from_strides,
1947                                            &from_strides[from_desc.ndims]);
1948     memory::dims to_strides_outer_blocks(to_strides,
1949                                          &to_strides[to_desc.ndims]);
1950 
1951     key_creator.AddAsKey(prefix);
1952     key_creator.AddAsKey(static_cast<int>(from_desc.extra.flags));
1953     key_creator.AddAsKey(static_cast<int>(from_inner_nblks));
1954     key_creator.AddAsKey(from_inner_blks_1);
1955     key_creator.AddAsKey(from_inner_idxs_1);
1956     key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
1957     key_creator.AddAsKey(from_dims);
1958     key_creator.AddAsKey(from_strides_outer_blocks);
1959     key_creator.AddAsKey(static_cast<int>(to_desc.extra.flags));
1960     key_creator.AddAsKey(static_cast<int>(to_inner_nblks));
1961     key_creator.AddAsKey(to_inner_blks_1);
1962     key_creator.AddAsKey(to_inner_idxs_1);
1963     key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
1964     key_creator.AddAsKey(to_dims);
1965     key_creator.AddAsKey(to_strides_outer_blocks);
1966     return key_creator.GetKey();
1967   }
1968 
1969  private:
MklReorderPrimitiveFactory()1970   MklReorderPrimitiveFactory() {}
~MklReorderPrimitiveFactory()1971   ~MklReorderPrimitiveFactory() {}
1972 
GetReorder(const memory * from,const memory * to)1973   MklPrimitive* GetReorder(const memory* from, const memory* to) {
1974     string key = CreateKey(from, to);
1975     return this->GetOp(key);
1976   }
1977 
SetReorder(const memory * from,const memory * to,MklPrimitive * op)1978   void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
1979     string key = CreateKey(from, to);
1980     this->SetOp(key, op);
1981   }
1982 };
1983 
1984 /// Function to find(or create) a reorder from memory pointed by
1985 /// from to memory pointed by to, it will created primitive or
1986 /// get primitive from pool if it is cached.
1987 /// Returns the primitive.
1988 template <typename T>
FindOrCreateReorder(const memory * from,const memory * to)1989 inline MklReorderPrimitive* FindOrCreateReorder(const memory* from,
1990                                                 const memory* to) {
1991   CHECK_NOTNULL(from);
1992   CHECK_NOTNULL(to);
1993   MklReorderPrimitive* reorder_prim =
1994       MklReorderPrimitiveFactory<T>::Get(from, to);
1995   return reorder_prim;
1996 }
1997 
1998 // utility function to determine if it is conv 1x1 and stride != 1
1999 // for purpose of temporarily disabling primitive reuse
IsConv1x1StrideNot1(memory::dims filter_dims,memory::dims strides)2000 inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
2001                                 memory::dims strides) {
2002   if (filter_dims.size() != 4 || strides.size() != 2) return false;
2003 
2004   return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
2005           ((strides[0] != 1) || (strides[1] != 1)));
2006 }
2007 
2008 }  // namespace tensorflow
2009 
/////////////////////////////////////////////////////////////////////
// Macros for handling registration for various types
//
// Every macro below forwards to REGISTER_TEST, which is not defined in this
// part of the file and has to be provided by the code that uses these
// macros (presumably unit tests -- confirm). Each expansion already ends
// with ';'.
/////////////////////////////////////////////////////////////////////

// Registers TEST with DT_FLOAT / Float32Input.
#define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input);

#ifdef ENABLE_INTEL_MKL_BFLOAT16
// Registers TEST with DT_BFLOAT16 / BFloat16Input. Only defined when
// ENABLE_INTEL_MKL_BFLOAT16 is defined.
#define REGISTER_TEST_BFLOAT16(TEST) \
  REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input);

// Registers TEST for float32 and bfloat16.
#define REGISTER_TEST_ALL_TYPES(TEST) \
  REGISTER_TEST_FLOAT32(TEST);        \
  REGISTER_TEST_BFLOAT16(TEST);
#else
// Without ENABLE_INTEL_MKL_BFLOAT16, "all types" means float32 only.
#define REGISTER_TEST_ALL_TYPES(TEST) REGISTER_TEST_FLOAT32(TEST);
#endif  // ENABLE_INTEL_MKL_BFLOAT16
2026 
2027 #endif  // INTEL_MKL
2028 #endif  // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
2029