1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
17 #define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
18 #ifdef INTEL_MKL
19
20 #include <list>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26
27 #if defined(INTEL_MKL_ML_ONLY) || defined(INTEL_MKL_DNN_ONLY)
28 #ifndef INTEL_MKL
29 #error "INTEL_MKL_{ML,DNN}_ONLY require INTEL_MKL"
30 #endif
31 #endif
32
33 #if defined(INTEL_MKL_ML_ONLY) && defined(INTEL_MKL_DNN_ONLY)
34 #error "at most one of INTEL_MKL_ML_ONLY and INTEL_MKL_DNN_ONLY may be defined"
35 #endif
36
37 #ifdef INTEL_MKL_ML_ONLY
38 #error "Please use INTEL MKL DNN (the default option for --config=mkl)."
39 #endif
40
41 #ifdef INTEL_MKL_ML_ONLY
42 #include "mkl_dnn.h"
43 #include "mkl_dnn_types.h"
44 #include "mkl_service.h"
45 #include "mkl_trans.h"
46 #endif
47
48 #include "tensorflow/core/framework/op_kernel.h"
49 #include "tensorflow/core/framework/tensor.h"
50 #include "tensorflow/core/framework/tensor_shape.h"
51 #include "tensorflow/core/graph/mkl_graph_util.h"
52 #include "tensorflow/core/lib/core/errors.h"
53 #include "tensorflow/core/lib/gtl/array_slice.h"
54 #include "tensorflow/core/platform/cpu_info.h"
55 #include "tensorflow/core/platform/logging.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/util/env_var.h"
58 #include "tensorflow/core/util/padding.h"
59 #include "tensorflow/core/util/tensor_format.h"
60
61 #ifndef INTEL_MKL_ML_ONLY
62 #include "mkldnn.hpp"
63 #include "tensorflow/core/lib/core/stringpiece.h"
64
65 using mkldnn::engine;
66 using mkldnn::memory;
67 using mkldnn::padding_kind;
68 using mkldnn::primitive;
69 using mkldnn::reorder;
70 #endif
71
72 #ifdef _WIN32
73 typedef unsigned int uint;
74 #endif
75
76 namespace tensorflow {
77
78 // The file contains a number of utility classes and functions used by MKL
79 // enabled kernels
80
81 // This class encapsulates all the meta data that is associated with an MKL
82 // tensor. A tensor is an MKL tensor if it was created as the result of an
83 // MKL operation, and did not go through a conversion to a standard
84 // Tensorflow tensor.
85
86 // For use with MKL ML, has been deprecated
87 typedef enum { W = 0, H = 1, C = 2, N = 3 } MklDims;
88
89 // The dimensions order that MKL-DNN internally uses for 2D activations
90 // [Batch, Channel, Height, Width] and
91 // for 2D filters [Out_Channel, In_Channel, Height, Width].
92 typedef enum {
93 Dim_N = 0,
94 Dim_C = 1,
95 Dim_H = 2,
96 Dim_W = 3,
97 Dim_O = 0,
98 Dim_I = 1
99 } MklDnnDims;
100
101 // The dimensions order that MKL-DNN internally uses for 3D activations
102 // [Batch, Channel, Depth, Height, Width] and
103 // for 3D filters [Out_Channel, In_Channel, Depth, Height, Width].
104 typedef enum {
105 Dim3d_N = 0,
106 Dim3d_C = 1,
107 Dim3d_D = 2,
108 Dim3d_H = 3,
109 Dim3d_W = 4,
110 Dim3d_O = 0,
111 Dim3d_I = 1
112 } MklDnnDims3D;
113
114 // Enum for the order of dimensions of a TF 2D filter with shape [filter_height,
115 // filter_width, in_channels, out_channels]
116 typedef enum {
117 TF_2DFILTER_DIM_H = 0,
118 TF_2DFILTER_DIM_W = 1,
119 TF_2DFILTER_DIM_I = 2,
120 TF_2DFILTER_DIM_O = 3
121 } TFFilterDims2d;
122
123 // Enum for the order of dimensions of a TF 3D filter with shape [filter_depth,
124 // filter_height, filter_width, in_channels, out_channels]
125 typedef enum {
126 TF_3DFILTER_DIM_P = 0,
127 TF_3DFILTER_DIM_H = 1,
128 TF_3DFILTER_DIM_W = 2,
129 TF_3DFILTER_DIM_I = 3,
130 TF_3DFILTER_DIM_O = 4
131 } TFFilterDims3d;
132
133 // The dimensions order that MKL-DNN requires for the filter in a grouped
134 // convolution (2D only)
135 typedef enum {
136 MKL_GROUP_FILTER_DIM_G = 0,
137 MKL_GROUP_FILTER_DIM_O = 1,
138 MKL_GROUP_FILTER_DIM_I = 2,
139 MKL_GROUP_FILTER_DIM_H = 3,
140 MKL_GROUP_FILTER_DIM_W = 4
141 } MklDnnFilterGroupDims;
142
143 // Enum used to templatize MklOp kernel implementations
144 // that support both fp32 and int8 versions.
145 enum class MklQuantization {
146 QUANTIZED_VERSION,
147 FP_VERSION,
148 };
149
150 static const int kSmallBatchSize = 32;
151
152 #ifdef INTEL_MKL_ML_ONLY
153 class MklShape {
154 public:
MklShape()155 MklShape() {}
156 TF_DISALLOW_COPY_AND_ASSIGN(MklShape); // Cannot copy
157
~MklShape()158 ~MklShape() {
159 if (sizes_) delete[] sizes_;
160 if (strides_) delete[] strides_;
161 if (mklLayout_) CHECK_EQ(dnnLayoutDelete_F32(mklLayout_), E_SUCCESS);
162 if (tfLayout_) CHECK_EQ(dnnLayoutDelete_F32(tfLayout_), E_SUCCESS);
163 if (tf_to_mkl_dim_map_) delete[] tf_to_mkl_dim_map_;
164 }
165
IsMklTensor()166 const bool IsMklTensor() const { return isMklTensor_; }
167
SetMklTensor(const bool isMklTensor)168 void SetMklTensor(const bool isMklTensor) { isMklTensor_ = isMklTensor; }
169
SetDimensions(const size_t dimension)170 void SetDimensions(const size_t dimension) { dimension_ = dimension; }
171
SetMklLayout(dnnLayout_t mklLayout)172 void SetMklLayout(dnnLayout_t mklLayout) { mklLayout_ = mklLayout; }
173
SetMklLayout(const void * primitive,size_t resourceType)174 void SetMklLayout(const void* primitive, size_t resourceType) {
175 CHECK_EQ(
176 dnnLayoutCreateFromPrimitive_F32(&mklLayout_, (dnnPrimitive_t)primitive,
177 (dnnResourceType_t)resourceType),
178 E_SUCCESS);
179 }
180
SetTfLayout(const size_t dimension,const size_t * sizes,const size_t * strides)181 void SetTfLayout(const size_t dimension, const size_t* sizes,
182 const size_t* strides) {
183 dimension_ = dimension;
184 if (dimension > 0) { // MKl doesn't support zero dimension tensors
185 sizes_ = new size_t[dimension];
186 strides_ = new size_t[dimension];
187
188 for (int ii = 0; ii < dimension; ii++) {
189 sizes_[ii] = sizes[ii];
190 strides_[ii] = strides[ii];
191 }
192 CHECK_EQ(dnnLayoutCreate_F32(&tfLayout_, dimension, sizes, strides),
193 E_SUCCESS);
194 }
195 }
196
197 // Default case - MKL dim ordering is opposite of TF dim ordering
198 // MKL -> (DIMS-1)...0 where (DIMS-1) is outermost dim and 0 is innermost dim
199 // TF -> 0...(DIMS-1) where 0 is outermost dim and (DIMS-1) is innermost dim
200 // For layers that rely on data_format semantics (conv, pooling etc.)
201 // or operate only on certain dimensions (relu, concat, split etc.),
202 // Mkl APIs might require us to reorder these dimensions. In such cases,
203 // kernels should explicitly set this map
SetTfDimOrder(const size_t dimension)204 void SetTfDimOrder(const size_t dimension) {
205 CHECK(dimension == dimension_);
206 if (tf_to_mkl_dim_map_ == nullptr) {
207 tf_to_mkl_dim_map_ = new size_t[dimension];
208 }
209 for (size_t ii = 0; ii < dimension; ii++) {
210 tf_to_mkl_dim_map_[ii] = dimension - (ii + 1);
211 }
212 }
213
SetTfDimOrder(const size_t dimension,const size_t * tf_to_mkl_dim_map)214 void SetTfDimOrder(const size_t dimension, const size_t* tf_to_mkl_dim_map) {
215 CHECK(dimension == dimension_);
216 if (tf_to_mkl_dim_map_ == nullptr) {
217 tf_to_mkl_dim_map_ = new size_t[dimension];
218 }
219 for (size_t ii = 0; ii < dimension; ii++) {
220 tf_to_mkl_dim_map_[ii] = tf_to_mkl_dim_map[ii];
221 }
222 }
223
SetTfDimOrder(const size_t dimension,TensorFormat data_format)224 void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
225 CHECK_EQ(dimension, 4);
226 CHECK(dimension == dimension_);
227 if (tf_to_mkl_dim_map_ == nullptr) {
228 tf_to_mkl_dim_map_ = new size_t[dimension];
229 }
230 tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDims::W;
231 tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDims::H;
232 tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDims::C;
233 tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDims::N;
234 }
235
GetMklLayout()236 const dnnLayout_t GetMklLayout() const { return mklLayout_; }
GetTfLayout()237 const dnnLayout_t GetTfLayout() const { return tfLayout_; }
GetCurLayout()238 const dnnLayout_t GetCurLayout() const {
239 return isMklTensor_ ? mklLayout_ : tfLayout_;
240 }
GetDimension()241 size_t GetDimension() const { return dimension_; }
GetSizes()242 const size_t* GetSizes() const { return sizes_; }
dim_size(int index)243 int64 dim_size(int index) const { return sizes_[index]; }
tf_dim_size(int index)244 int64 tf_dim_size(int index) const {
245 return sizes_[tf_to_mkl_dim_map_[index]];
246 }
GetStrides()247 const size_t* GetStrides() const { return strides_; }
GetTfToMklDimMap()248 const size_t* GetTfToMklDimMap() const { return tf_to_mkl_dim_map_; }
tf_dim_idx(int index)249 size_t tf_dim_idx(int index) const { return tf_to_mkl_dim_map_[index]; }
250
251 // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
252 // corresponds to MKL's Channel dimension.
IsMklChannelDim(int d)253 bool IsMklChannelDim(int d) const { return tf_dim_idx(d) == MklDims::C; }
254 // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
255 // corresponds to MKL's Batch dimension.
IsMklBatchDim(int d)256 bool IsMklBatchDim(int d) const { return tf_dim_idx(d) == MklDims::N; }
257 // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
258 // corresponds to MKL's Width dimension.
IsMklWidthDim(int d)259 bool IsMklWidthDim(int d) const { return tf_dim_idx(d) == MklDims::W; }
260 // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
261 // corresponds to MKL's Height dimension.
IsMklHeightDim(int d)262 bool IsMklHeightDim(int d) const { return tf_dim_idx(d) == MklDims::H; }
263
264 // Check if the TF-Mkl dimension ordering map specifies if the input
265 // tensor is in NCHW format.
IsTensorInNCHWFormat()266 bool IsTensorInNCHWFormat() const {
267 TensorFormat data_format = FORMAT_NCHW;
268 return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
269 IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
270 IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
271 IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
272 }
273
274 // Check if the TF-Mkl dimension ordering map specifies if the input
275 // tensor is in NHWC format.
IsTensorInNHWCFormat()276 bool IsTensorInNHWCFormat() const {
277 TensorFormat data_format = FORMAT_NHWC;
278 return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
279 IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
280 IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
281 IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
282 }
283
GetConvertedFlatData(dnnLayout_t targetLayout,void * input,void * output)284 void GetConvertedFlatData(dnnLayout_t targetLayout, void* input,
285 void* output) const {
286 dnnLayout_t curLayout;
287 if (isMklTensor_)
288 curLayout = mklLayout_;
289 else
290 curLayout = tfLayout_;
291 dnnPrimitive_t convert;
292 CHECK_EQ(dnnConversionCreate_F32(&convert, curLayout, targetLayout),
293 E_SUCCESS);
294 CHECK_EQ(dnnConversionExecute_F32(convert, input, output), E_SUCCESS);
295 CHECK_EQ(dnnDelete_F32(convert), E_SUCCESS);
296 }
297
298 // The following methods are used for serializing and de-serializing the
299 // contents of the mklshape object.
300 // The data is serialized in this order
301 // isMklTensor_
302 // dimension_
303 // sizes_
304 // strides_
305 // mklLayout_
306 // tfLayout_
307 // tf_to_mkl_dim_map_
308
309 #define SIZE_OF_MKL_DNN_BUF \
310 (dnnLayoutSerializationBufferSize_F32()) // Size of buffer needed to
311 // serialize dnn_layout pointer
312
313 // Size of buffer to hold the serialized object, the size is computed as
314 // follows sizeof(isMklTensor_) + sizeof(dimension_) + sizeof(sizes_) +
315 // sizeof(strides_)
316 // + sizeof(mklLayout_ buffer) + sizeof(tfLayout_ buffer)
317 // + sizeof(tf_to_mkl_dim_map_)
318
319 #define SIZE_OF_MKL_SERIAL_DATA(dims) \
320 (2 * sizeof(size_t) + 3 * dims * sizeof(size_t) + 2 * SIZE_OF_MKL_DNN_BUF)
321
322 // First we need to define some macro for offsets into the serial buffer where
323 // different elements of Mklshape is written/read from
324
325 #define IS_MKL_TENSOR_OFFSET 0
326 // Location from start of buffer where isMklTensor_ is serialized
327 #define DIMS_OFFSET \
328 (IS_MKL_TENSOR_OFFSET + sizeof(size_t)) // Location of dimension_
329 // Location of sizes. Note dim is not used here, left here
330 // to make macros consistent.
331 #define SIZES_OFFSET(dims) (DIMS_OFFSET + sizeof(size_t))
332 #define STRIDES_OFFSET(dims) \
333 (SIZES_OFFSET(dims) + dims * sizeof(size_t)) // Location of strides
334 #define MKL_LAYOUT_OFFSET(dims) \
335 (STRIDES_OFFSET(dims) + dims * sizeof(size_t)) // Location of mklLayout_
336 #define TF_LAYOUT_OFFSET(dims) \
337 (MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF) // Location of tfLayout_
338 // Location of tf_to_mkl_dim_map_
339 #define TF_TO_MKL_DIM_MAP_OFFSET(dims) \
340 (TF_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)
341
342 // TODO(agramesh1) make sure to create a const to share with rewrite pass
343 // for min size of MKL metadata tensor.
344
DeSerializeMklShape(const unsigned char * buf,size_t buf_size)345 void DeSerializeMklShape(const unsigned char* buf, size_t buf_size) {
346 CHECK(buf_size >= sizeof(size_t)) << "Bufsize too small in DeSerialize";
347 // Make sure buffer holds at least isMklTensor_
348 isMklTensor_ =
349 *reinterpret_cast<const size_t*>(buf + IS_MKL_TENSOR_OFFSET) != 0;
350
351 if (isMklTensor_) { // If it is an MKL Tensor then read the rest
352 dimension_ = *(reinterpret_cast<const size_t*>(buf + DIMS_OFFSET));
353 CHECK(buf_size >= SIZE_OF_MKL_SERIAL_DATA(dimension_))
354 << "Bufsize too small in DeSerialize";
355 sizes_ = new size_t[dimension_];
356 strides_ = new size_t[dimension_];
357 tf_to_mkl_dim_map_ = new size_t[dimension_];
358 for (int i = 0; i < dimension_; i++) {
359 sizes_[i] =
360 reinterpret_cast<const size_t*>(buf + SIZES_OFFSET(dimension_))[i];
361 strides_[i] = reinterpret_cast<const size_t*>(
362 buf + STRIDES_OFFSET(dimension_))[i];
363 tf_to_mkl_dim_map_[i] = reinterpret_cast<const size_t*>(
364 buf + TF_TO_MKL_DIM_MAP_OFFSET(dimension_))[i];
365 }
366 CHECK_EQ(dnnLayoutDeserialize_F32(&mklLayout_,
367 buf + MKL_LAYOUT_OFFSET(dimension_)),
368 E_SUCCESS);
369 CHECK_EQ(dnnLayoutDeserialize_F32(&tfLayout_,
370 buf + TF_LAYOUT_OFFSET(dimension_)),
371 E_SUCCESS);
372 }
373 }
374
SerializeMklShape(unsigned char * buf,size_t buf_size)375 void SerializeMklShape(unsigned char* buf, size_t buf_size) const {
376 CHECK(buf_size >= SIZE_OF_MKL_SERIAL_DATA(dimension_))
377 << "Bufsize too small to Serialize";
378 *reinterpret_cast<size_t*>(buf + IS_MKL_TENSOR_OFFSET) =
379 isMklTensor_ ? 1 : 0;
380 if (isMklTensor_) {
381 *(reinterpret_cast<size_t*>(buf + DIMS_OFFSET)) = dimension_;
382 for (int i = 0; i < dimension_; i++) {
383 reinterpret_cast<size_t*>(buf + SIZES_OFFSET(dimension_))[i] =
384 sizes_[i];
385 reinterpret_cast<size_t*>(buf + STRIDES_OFFSET(dimension_))[i] =
386 strides_[i];
387 reinterpret_cast<size_t*>(buf +
388 TF_TO_MKL_DIM_MAP_OFFSET(dimension_))[i] =
389 tf_to_mkl_dim_map_[i];
390 }
391 CHECK_EQ(dnnLayoutSerialize_F32(mklLayout_,
392 buf + MKL_LAYOUT_OFFSET(dimension_)),
393 E_SUCCESS);
394 CHECK_EQ(
395 dnnLayoutSerialize_F32(tfLayout_, buf + TF_LAYOUT_OFFSET(dimension_)),
396 E_SUCCESS);
397 }
398 }
399
400 private:
401 bool isMklTensor_ =
402 false; // Flag to indicate if the tensor is an MKL tensor or not
403 dnnLayout_t mklLayout_ = nullptr; // Pointer to the MKL layout
404 dnnLayout_t tfLayout_ = nullptr; // Pointer to layout of corresponding
405 // Tensorflow tensor, used when conversion from MKL to standard tensor
406 size_t dimension_ = 0;
407 size_t* sizes_ = nullptr; // Required by MKL for conversions
408 size_t* strides_ = nullptr; // Required by MKL for conversions
409 size_t* tf_to_mkl_dim_map_ =
410 nullptr; // TF dimension corresponding to this MKL dimension
411 };
412
413 #else
414
415 // Forward decl
416 TensorFormat MklDnn3DDataFormatToTFDataFormat(memory::format format);
417 TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
418 memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
419 memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
420 const memory::dims& strides,
421 memory::data_type dtype);
422
423 class MklDnnShape {
424 private:
425 typedef struct {
426 /// Flag to indicate if the tensor is an MKL tensor or not
427 bool is_mkl_tensor_ = false;
428 /// Number of dimensions in Tensorflow format
429 size_t dimension_ = 0;
430 /// Required by MKLDNN for conversions
431 mkldnn_dims_t sizes_; // Required by MKL for conversions
432 memory::format tf_data_format_ = memory::format::format_undef;
433 memory::data_type T_ = memory::data_type::data_undef;
434 // MKL layout
435 mkldnn_memory_desc_t mkl_md_;
436 /// TF dimension corresponding to this MKL dimension
437 mkldnn_dims_t map_;
438 } MklShapeData;
439 MklShapeData data_;
440
441 typedef std::remove_extent<mkldnn_dims_t>::type mkldnn_dim_t;
442 #define INVALID_DIM_SIZE -1
443
444 public:
MklDnnShape()445 MklDnnShape() {
446 for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
447 ++i) {
448 data_.sizes_[i] = -1;
449 }
450 for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) {
451 data_.map_[i] = -1;
452 }
453 }
454
~MklDnnShape()455 ~MklDnnShape() {}
456 TF_DISALLOW_COPY_AND_ASSIGN(MklDnnShape); // Cannot copy
457
458 /// Helper function to compare memory::desc objects for MklDnn.
459 /// May be this should go into MklDnn directly.
CompareMklDnnLayouts(const memory::desc & md1,const memory::desc & md2)460 inline bool CompareMklDnnLayouts(const memory::desc& md1,
461 const memory::desc& md2) const {
462 mkldnn_memory_desc_t mdd1 = md1.data;
463 mkldnn_memory_desc_t mdd2 = md2.data;
464 const char* d1 = reinterpret_cast<const char*>(&mdd1);
465 const char* d2 = reinterpret_cast<const char*>(&mdd2);
466
467 size_t md_size = sizeof(mdd1);
468 for (size_t i = 0; i < md_size; i++) {
469 if (*d1++ != *d2++) {
470 return false;
471 }
472 }
473 return true;
474 }
475
476 /// Equality function for MklDnnShape objects
477 /// @return true if both are equal; false otherwise.
478 inline bool operator==(const MklDnnShape& input_shape) const {
479 if (this->IsMklTensor() != input_shape.IsMklTensor()) {
480 return false;
481 }
482
483 // If input tensors are in Mkl layout, then we check for dimensions and
484 // sizes.
485 if (this->IsMklTensor()) {
486 return this->GetTfShape() == input_shape.GetTfShape() &&
487 CompareMklDnnLayouts(this->GetMklLayout(),
488 input_shape.GetMklLayout());
489 }
490
491 return true;
492 }
493
494 /// Equality operator for MklDnnShape and TFShape.
495 /// Returns: true if TF shapes for both are the same, false otherwise
496 inline bool operator==(const TensorShape& input_shape) const {
497 if (!this->IsMklTensor()) {
498 return false;
499 }
500
501 return this->GetTfShape() == input_shape;
502 }
503
IsMklTensor()504 inline const bool IsMklTensor() const { return data_.is_mkl_tensor_; }
SetMklTensor(bool is_mkl_tensor)505 inline void SetMklTensor(bool is_mkl_tensor) {
506 data_.is_mkl_tensor_ = is_mkl_tensor;
507 }
508
SetDimensions(const size_t dimension)509 inline void SetDimensions(const size_t dimension) {
510 data_.dimension_ = dimension;
511 }
GetDimension(char dimension)512 inline size_t GetDimension(char dimension) const {
513 int index = GetMklDnnTensorDimIndex(dimension);
514 CHECK(index >= 0 && index < this->GetDimension())
515 << "Invalid index from the dimension: " << index << ", " << dimension;
516 return this->DimSize(index);
517 }
518
GetDimension3D(char dimension)519 inline size_t GetDimension3D(char dimension) const {
520 int index = GetMklDnnTensor3DDimIndex(dimension);
521 CHECK(index >= 0 && index < this->GetDimension())
522 << "Invalid index from the dimension: " << index << ", " << dimension;
523 return this->DimSize(index);
524 }
525
GetMklDnnTensorDimIndex(char dimension)526 inline int32 GetMklDnnTensorDimIndex(char dimension) const {
527 switch (dimension) {
528 case 'N':
529 return MklDnnDims::Dim_N;
530 case 'C':
531 return MklDnnDims::Dim_C;
532 case 'H':
533 return MklDnnDims::Dim_H;
534 case 'W':
535 return MklDnnDims::Dim_W;
536 default:
537 LOG(FATAL) << "Invalid dimension: " << dimension;
538 return -1; // Avoid compiler warning about missing return value
539 }
540 }
541
GetMklDnnTensor3DDimIndex(char dimension)542 inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
543 switch (dimension) {
544 case 'N':
545 return MklDnnDims3D::Dim3d_N;
546 case 'C':
547 return MklDnnDims3D::Dim3d_C;
548 case 'D':
549 return MklDnnDims3D::Dim3d_D;
550 case 'H':
551 return MklDnnDims3D::Dim3d_H;
552 case 'W':
553 return MklDnnDims3D::Dim3d_W;
554 default:
555 LOG(FATAL) << "Invalid dimension: " << dimension;
556 return -1; // Avoid compiler warning about missing return value
557 }
558 }
559
GetDimension()560 inline size_t GetDimension() const { return data_.dimension_; }
GetSizes()561 inline const int* GetSizes() const {
562 return reinterpret_cast<const int*>(&data_.sizes_[0]);
563 }
564
565 // Returns an mkldnn::memory::dims object that contains the sizes of this
566 // MklDnnShape object.
GetSizesAsMklDnnDims()567 inline memory::dims GetSizesAsMklDnnDims() const {
568 memory::dims retVal;
569 if (data_.is_mkl_tensor_) {
570 size_t dimensions = sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
571 for (size_t i = 0; i < dimensions; i++) {
572 if (data_.sizes_[i] != INVALID_DIM_SIZE)
573 retVal.push_back(data_.sizes_[i]);
574 }
575 } else {
576 CHECK_EQ(data_.is_mkl_tensor_, true);
577 }
578 return retVal;
579 }
580
DimSize(int index)581 inline int64 DimSize(int index) const {
582 CHECK_LT(index, sizeof(data_.sizes_) / sizeof(data_.sizes_[0]));
583 return data_.sizes_[index];
584 }
585
586 /// Return TensorShape that describes the Tensorflow shape of the tensor
587 /// represented by this MklShape.
GetTfShape()588 inline TensorShape GetTfShape() const {
589 CHECK_EQ(data_.is_mkl_tensor_, true);
590
591 std::vector<int32> shape(data_.dimension_, -1);
592 if (data_.tf_data_format_ != memory::format::blocked) {
593 for (size_t idx = 0; idx < data_.dimension_; ++idx) {
594 shape[idx] = data_.sizes_[TfDimIdx(idx)];
595 }
596 } else {
597 // If Tensorflow shape is in Blocked format, then we don't have dimension
598 // map for it. So we just create Tensorflow shape from sizes in the
599 // specified order.
600 for (size_t idx = 0; idx < data_.dimension_; ++idx) {
601 shape[idx] = data_.sizes_[idx];
602 }
603 }
604
605 TensorShape ts;
606 bool ret = TensorShapeUtils::MakeShape(shape, &ts).ok();
607 CHECK_EQ(ret, true);
608 return ts;
609 }
610
SetElemType(memory::data_type dt)611 inline void SetElemType(memory::data_type dt) { data_.T_ = dt; }
GetElemType()612 inline const memory::data_type GetElemType() { return data_.T_; }
613
SetMklLayout(memory::primitive_desc * pd)614 inline void SetMklLayout(memory::primitive_desc* pd) {
615 CHECK_NOTNULL(pd);
616 data_.mkl_md_ = pd->desc().data;
617 }
618
SetMklLayout(memory::desc * md)619 inline void SetMklLayout(memory::desc* md) {
620 CHECK_NOTNULL(md);
621 data_.mkl_md_ = md->data;
622 }
623
GetMklLayout()624 inline const memory::desc GetMklLayout() const {
625 return memory::desc(data_.mkl_md_);
626 }
627
GetTfDataFormat()628 inline memory::format GetTfDataFormat() const {
629 return data_.tf_data_format_;
630 }
631 /// We don't create primitive_descriptor for TensorFlow layout now.
632 /// We use lazy evaluation and create it only when needed. Input format can
633 /// also be Blocked format.
SetTfLayout(size_t dims,const memory::dims & sizes,memory::format format)634 inline void SetTfLayout(size_t dims, const memory::dims& sizes,
635 memory::format format) {
636 CHECK_EQ(dims, sizes.size());
637 data_.dimension_ = dims;
638 for (size_t ii = 0; ii < dims; ii++) {
639 data_.sizes_[ii] = sizes[ii];
640 }
641 data_.tf_data_format_ = format;
642 if (format != memory::format::blocked) {
643 SetTfDimOrder(dims, format);
644 }
645 }
646
GetTfLayout()647 inline const memory::desc GetTfLayout() const {
648 memory::dims dims;
649 for (size_t ii = 0; ii < data_.dimension_; ii++) {
650 dims.push_back(data_.sizes_[ii]);
651 }
652
653 // Create Blocked memory desc if input TF format was set like that.
654 if (data_.tf_data_format_ == memory::format::blocked) {
655 auto strides = CalculateTFStrides(dims);
656 return CreateBlockedMemDescHelper(dims, strides, data_.T_);
657 } else {
658 return memory::desc(dims, data_.T_, data_.tf_data_format_);
659 }
660 }
661
GetCurLayout()662 inline const memory::desc GetCurLayout() const {
663 return IsMklTensor() ? GetMklLayout() : GetTfLayout();
664 }
665
666 // nhasabni - I've removed SetTfDimOrder that was setting default order in
667 // case of MKL-ML. We don't need a case of default dimension order because
668 // when an operator that does not get data_format attribute gets all inputs
669 // in Tensorflow format, it will produce output in Tensorflow format.
SetTfDimOrder(const size_t dimension,const mkldnn_dims_t map)670 inline void SetTfDimOrder(const size_t dimension, const mkldnn_dims_t map) {
671 CHECK(dimension == data_.dimension_);
672 for (size_t ii = 0; ii < dimension; ii++) {
673 data_.map_[ii] = map[ii];
674 }
675 }
676
SetTfDimOrder(const size_t dimension,TensorFormat data_format)677 inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
678 if (dimension == 5) {
679 CHECK(dimension == data_.dimension_);
680 data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
681 MklDnnDims3D::Dim3d_D;
682 data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
683 MklDnnDims3D::Dim3d_H;
684 data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
685 MklDnnDims3D::Dim3d_W;
686 data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
687 MklDnnDims3D::Dim3d_C;
688 data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
689 MklDnnDims3D::Dim3d_N;
690 } else {
691 CHECK_EQ(dimension, 4);
692 CHECK(dimension == data_.dimension_);
693 data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
694 data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
695 data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
696 data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
697 }
698 }
699
SetTfDimOrder(const size_t dimension,memory::format format)700 inline void SetTfDimOrder(const size_t dimension, memory::format format) {
701 TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
702 SetTfDimOrder(dimension, data_format);
703 }
704
GetTfToMklDimMap()705 inline const mkldnn_dim_t* GetTfToMklDimMap() const { return &data_.map_[0]; }
TfDimIdx(int index)706 inline size_t TfDimIdx(int index) const { return data_.map_[index]; }
TfDimSize(int index)707 inline int64 TfDimSize(int index) const {
708 return data_.sizes_[TfDimIdx(index)];
709 }
710
711 /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
712 /// corresponds to MKL's Channel dimension.
IsMklChannelDim(int d)713 inline bool IsMklChannelDim(int d) const {
714 return TfDimIdx(d) == MklDnnDims::Dim_C;
715 }
716 /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
717 /// corresponds to MKL's Batch dimension.
IsMklBatchDim(int d)718 inline bool IsMklBatchDim(int d) const {
719 return TfDimIdx(d) == MklDnnDims::Dim_N;
720 }
721 /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
722 /// corresponds to MKL's Width dimension.
IsMklWidthDim(int d)723 inline bool IsMklWidthDim(int d) const {
724 return TfDimIdx(d) == MklDnnDims::Dim_W;
725 }
726 /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
727 /// corresponds to MKL's Height dimension.
IsMklHeightDim(int d)728 inline bool IsMklHeightDim(int d) const {
729 return TfDimIdx(d) == MklDnnDims::Dim_H;
730 }
731
732 /// Check if the TF-Mkl dimension ordering map specifies if the input
733 /// tensor is in NCHW format.
IsTensorInNCHWFormat()734 inline bool IsTensorInNCHWFormat() const {
735 TensorFormat data_format = FORMAT_NCHW;
736 return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
737 IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
738 IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
739 IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
740 }
741
742 /// Check if the TF-Mkl dimension ordering map specifies if the input
743 /// tensor is in NHWC format.
IsTensorInNHWCFormat()744 inline bool IsTensorInNHWCFormat() const {
745 TensorFormat data_format = FORMAT_NHWC;
746 return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
747 IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
748 IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
749 IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
750 }
751
752 /// The following methods are used for serializing and de-serializing the
753 /// contents of the mklshape object.
754 /// The data is serialized in this order
755 /// is_mkl_tensor_ : dimension_ : sizes_ : map_: format_ : T_ : mkl_pd_;
756
757 /// Size of buffer to hold the serialized object, the size is computed by
758 /// following above mentioned order
GetSerializeBufferSize()759 inline size_t GetSerializeBufferSize() const { return sizeof(MklShapeData); }
760
SerializeMklDnnShape(unsigned char * buf,size_t buf_size)761 void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const {
762 CHECK(buf_size >= GetSerializeBufferSize())
763 << "Buffer size is too small to SerializeMklDnnShape";
764 *reinterpret_cast<MklShapeData*>(buf) = data_;
765 }
766
DeSerializeMklDnnShape(const unsigned char * buf,size_t buf_size)767 void DeSerializeMklDnnShape(const unsigned char* buf, size_t buf_size) {
768 // Make sure buffer holds at least is_mkl_tensor_.
769 CHECK(buf_size >= sizeof(data_.is_mkl_tensor_))
770 << "Buffer size is too small in DeSerializeMklDnnShape";
771
772 const bool is_mkl_tensor = *reinterpret_cast<const bool*>(buf);
773 if (is_mkl_tensor) { // If it is an MKL Tensor then read the rest
774 CHECK(buf_size >= GetSerializeBufferSize())
775 << "Buffer size is too small in DeSerializeMklDnnShape";
776 data_ = *reinterpret_cast<const MklShapeData*>(buf);
777 }
778 }
779 };
780
781 #endif
782
783 // List of MklShape objects. Used in Concat/Split layers.
784
785 #ifndef INTEL_MKL_ML_ONLY
786 typedef std::vector<MklDnnShape> MklDnnShapeList;
787 #else
788 typedef std::vector<MklShape> MklShapeList;
789 #endif
790
791 #ifdef INTEL_MKL_ML_ONLY
792 // Check if all tensors specified by MklShapes are MKL tensors.
AreAllMklTensors(const MklShapeList & shapes)793 inline bool AreAllMklTensors(const MklShapeList& shapes) {
794 for (auto& s : shapes) {
795 if (!s.IsMklTensor()) {
796 return false;
797 }
798 }
799 return true;
800 }
801
802 template <typename T>
ConvertMklToTF(OpKernelContext * context,const Tensor & mkl_tensor,const MklShape & mkl_shape)803 inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
804 const MklShape& mkl_shape) {
805 Tensor output_tensor;
806 TensorShape output_shape;
807
808 for (size_t j = 0; j < mkl_shape.GetDimension(); j++) {
809 // Outermost to innermost dimension
810 output_shape.AddDim(mkl_shape.GetSizes()[mkl_shape.tf_dim_idx(j)]);
811 }
812
813 // Allocate output tensor.
814 context->allocate_temp(DataTypeToEnum<T>::v(), output_shape, &output_tensor);
815
816 dnnLayout_t output_layout = static_cast<dnnLayout_t>(mkl_shape.GetTfLayout());
817 void* input_buffer = const_cast<T*>(mkl_tensor.flat<T>().data());
818 void* output_buffer = const_cast<T*>(output_tensor.flat<T>().data());
819
820 if (mkl_tensor.NumElements() != 0) {
821 mkl_shape.GetConvertedFlatData(output_layout, input_buffer, output_buffer);
822 }
823
824 return output_tensor;
825 }
826 #else
827 using mkldnn::stream;
828 template <typename T>
829 class MklDnnData;
830
831 template <typename T>
ConvertMklToTF(OpKernelContext * context,const Tensor & mkl_tensor,const MklDnnShape & mkl_shape)832 inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
833 const MklDnnShape& mkl_shape) {
834 Tensor output_tensor;
835 try {
836 if (!mkl_shape.IsMklTensor())
837 return mkl_tensor; // return input since it is already TF tensor
838
839 TensorShape output_shape = mkl_shape.GetTfShape();
840
841 // Allocate output tensor.
842 context->allocate_temp(DataTypeToEnum<T>::v(), output_shape,
843 &output_tensor);
844
845 auto cpu_engine = engine(engine::cpu, 0);
846 MklDnnData<T> input(&cpu_engine);
847
848 // Get Mkl layout of input tensor.
849 auto input_mkl_md = mkl_shape.GetMklLayout();
850 auto output_tf_md = mkl_shape.GetTfLayout();
851 auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine);
852 input.SetUsrMem(input_mkl_md, &mkl_tensor);
853
854 // reorder
855 if (input.IsReorderNeeded(output_tf_pd)) {
856 std::vector<primitive> net;
857 CHECK_EQ(input.CheckReorderToOpMem(output_tf_pd, &output_tensor, &net),
858 true);
859 stream(stream::kind::eager).submit(net).wait();
860 } else {
861 // If not, just forward input tensor to output tensor.
862 CHECK(output_tensor.CopyFrom(mkl_tensor, output_shape));
863 }
864 } catch (mkldnn::error& e) {
865 string error_msg = "Status: " + std::to_string(e.status) +
866 ", message: " + string(e.message) + ", in file " +
867 string(__FILE__) + ":" + std::to_string(__LINE__);
868 LOG(FATAL) << "Operation received an exception: " << error_msg;
869 }
870 return output_tensor;
871 }
872 #endif
873
874 // Get the MKL shape from the second string tensor
875 #ifdef INTEL_MKL_ML_ONLY
GetMklShape(OpKernelContext * ctext,int n,MklShape * mklshape)876 inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
877 mklshape->DeSerializeMklShape(
878 ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
879 .flat<uint8>()
880 .data(),
881 ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
882 .flat<uint8>()
883 .size() *
884 sizeof(uint8));
885 }
886 #else
GetMklShape(OpKernelContext * ctext,int n,MklDnnShape * mklshape)887 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
888 mklshape->DeSerializeMklDnnShape(
889 ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
890 .flat<uint8>()
891 .data(),
892 ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
893 .flat<uint8>()
894 .size() *
895 sizeof(uint8));
896 }
897 #endif
898
899 // Gets the actual input
MklGetInput(OpKernelContext * ctext,int n)900 inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
901 return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
902 }
903
GetMklInputList(OpKernelContext * ctext,StringPiece name,OpInputList * input_tensors)904 inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
905 OpInputList* input_tensors) {
906 CHECK_NOTNULL(input_tensors);
907 ctext->input_list(name, input_tensors);
908 }
909
910 #ifdef INTEL_MKL_ML_ONLY
911
GetMklShapeList(OpKernelContext * ctext,StringPiece name,MklShapeList * mkl_shapes)912 inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
913 MklShapeList* mkl_shapes) {
914 OpInputList input_mkl_tensors;
915 GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
916
917 for (int i = 0; i < input_mkl_tensors.size(); i++) {
918 (*mkl_shapes)[i].DeSerializeMklShape(
919 input_mkl_tensors[i].flat<uint8>().data(),
920 input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
921 }
922 }
923
924 #else
925
GetMklShapeList(OpKernelContext * ctext,StringPiece name,MklDnnShapeList * mkl_shapes)926 inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
927 MklDnnShapeList* mkl_shapes) {
928 OpInputList input_mkl_tensors;
929 GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
930
931 for (int i = 0; i < input_mkl_tensors.size(); i++) {
932 (*mkl_shapes)[i].DeSerializeMklDnnShape(
933 input_mkl_tensors[i].flat<uint8>().data(),
934 input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
935 }
936 }
937
938 #endif
939
940 #ifndef INTEL_MKL_ML_ONLY
941 /// Get shape of input tensor pointed by 'input_idx' in TensorShape format.
942 /// If the input tensor is in MKL layout, then obtains TensorShape from
943 /// MklShape.
GetTfShape(OpKernelContext * context,size_t input_idx)944 inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx) {
945 // Sanity check.
946 CHECK_NOTNULL(context);
947 CHECK_LT(input_idx, context->num_inputs());
948
949 MklDnnShape input_mkl_shape;
950 GetMklShape(context, input_idx, &input_mkl_shape);
951 if (input_mkl_shape.IsMklTensor()) {
952 return input_mkl_shape.GetTfShape();
953 } else {
954 const Tensor& t = MklGetInput(context, input_idx);
955 return t.shape();
956 }
957 }
958 #endif
959
960 #ifdef INTEL_MKL_ML_ONLY
961 // Allocate the second output tensor that will contain
962 // the MKL shape serialized
AllocateOutputSetMklShape(OpKernelContext * ctext,int n,const MklShape & mkl_shape)963 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
964 const MklShape& mkl_shape) {
965 Tensor* second_tensor = nullptr;
966 TensorShape second_shape;
967 second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
968 OP_REQUIRES_OK(ctext, ctext->allocate_output(
969 GetTensorMetaDataIndex(n, ctext->num_outputs()),
970 second_shape, &second_tensor));
971 mkl_shape.SerializeMklShape(
972 second_tensor->flat<uint8>().data(),
973 second_tensor->flat<uint8>().size() * sizeof(uint8));
974 }
975
976 #else
977 // Allocate the second output tensor that will contain
978 // the MKL shape serialized
AllocateOutputSetMklShape(OpKernelContext * ctext,int n,const MklDnnShape & mkl_shape)979 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
980 const MklDnnShape& mkl_shape) {
981 Tensor* second_tensor = nullptr;
982 TensorShape second_shape;
983 second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
984 OP_REQUIRES_OK(ctext, ctext->allocate_output(
985 GetTensorMetaDataIndex(n, ctext->num_outputs()),
986 second_shape, &second_tensor));
987 mkl_shape.SerializeMklDnnShape(
988 second_tensor->flat<uint8>().data(),
989 second_tensor->flat<uint8>().size() * sizeof(uint8));
990 }
991 #endif
992
993 #ifdef INTEL_MKL_ML_ONLY
994 // Allocate the output tensor, create a second output tensor that will contain
995 // the MKL shape serialized
AllocateOutputSetMklShape(OpKernelContext * ctext,int n,Tensor ** output,const TensorShape & tf_shape,const MklShape & mkl_shape)996 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
997 Tensor** output,
998 const TensorShape& tf_shape,
999 const MklShape& mkl_shape) {
1000 Tensor* second_tensor = nullptr;
1001 TensorShape second_shape;
1002 second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
1003 OP_REQUIRES_OK(
1004 ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
1005 tf_shape, output));
1006 OP_REQUIRES_OK(ctext, ctext->allocate_output(
1007 GetTensorMetaDataIndex(n, ctext->num_outputs()),
1008 second_shape, &second_tensor));
1009 mkl_shape.SerializeMklShape(
1010 second_tensor->flat<uint8>().data(),
1011 second_tensor->flat<uint8>().size() * sizeof(uint8));
1012 }
1013
1014 #else
1015 // Allocate the output tensor, create a second output tensor that will contain
1016 // the MKL shape serialized
AllocateOutputSetMklShape(OpKernelContext * ctext,int n,Tensor ** output,const TensorShape & tf_shape,const MklDnnShape & mkl_shape)1017 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
1018 Tensor** output,
1019 const TensorShape& tf_shape,
1020 const MklDnnShape& mkl_shape) {
1021 Tensor* second_tensor = nullptr;
1022 TensorShape second_shape;
1023 second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
1024 OP_REQUIRES_OK(
1025 ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
1026 tf_shape, output));
1027 OP_REQUIRES_OK(ctext, ctext->allocate_output(
1028 GetTensorMetaDataIndex(n, ctext->num_outputs()),
1029 second_shape, &second_tensor));
1030 mkl_shape.SerializeMklDnnShape(
1031 second_tensor->flat<uint8>().data(),
1032 second_tensor->flat<uint8>().size() * sizeof(uint8));
1033 }
1034 #endif
1035
1036 // Allocates a temp tensor and returns the data buffer for temporary storage.
1037 // Currently
1038 #ifndef INTEL_MKL_ML_ONLY
1039 template <typename T>
AllocTmpBuffer(OpKernelContext * context,Tensor * tensor_out,const memory::primitive_desc & pd,void ** buf_out)1040 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
1041 const memory::primitive_desc& pd, void** buf_out) {
1042 TensorShape tf_shape;
1043
1044 tf_shape.AddDim(pd.get_size() / sizeof(T) + 1);
1045 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
1046 tf_shape, tensor_out));
1047 *buf_out = static_cast<void*>(tensor_out->flat<T>().data());
1048 }
1049 #else
AllocTmpBuffer(OpKernelContext * context,Tensor * tensor_out,dnnLayout_t lt_buff,void ** buf_out)1050 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
1051 dnnLayout_t lt_buff, void** buf_out) {
1052 TensorShape tf_shape;
1053
1054 tf_shape.AddDim(
1055 dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(lt_buff)) /
1056 sizeof(float) +
1057 1);
1058 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<float>::v(),
1059 tf_shape, tensor_out));
1060 *buf_out = static_cast<void*>(tensor_out->flat<float>().data());
1061 }
1062
1063 #endif
1064 template <typename T>
AllocTmpBuffer(OpKernelContext * context,Tensor * tensor_out,TensorShape tf_shape)1065 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
1066 TensorShape tf_shape) {
1067 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
1068 tf_shape, tensor_out));
1069 }
1070
GetStridesFromSizes(TensorFormat data_format,size_t * strides,const size_t * sizes)1071 inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
1072 const size_t* sizes) {
1073 // MKL requires strides in NCHW
1074 if (data_format == FORMAT_NHWC) {
1075 strides[0] = sizes[2];
1076 strides[1] = sizes[0] * sizes[2];
1077 strides[2] = 1;
1078 strides[3] = sizes[0] * sizes[1] * sizes[2];
1079 } else {
1080 strides[0] = 1;
1081 strides[1] = sizes[0];
1082 strides[2] = sizes[0] * sizes[1];
1083 strides[3] = sizes[0] * sizes[1] * sizes[2];
1084 }
1085 }
1086
1087 #ifdef INTEL_MKL_ML_ONLY
MklSizesToTFSizes(OpKernelContext * context,TensorFormat data_format_,const MklShape & mkl_shape,TensorShape * tf_shape)1088 inline void MklSizesToTFSizes(OpKernelContext* context,
1089 TensorFormat data_format_,
1090 const MklShape& mkl_shape,
1091 TensorShape* tf_shape) {
1092 size_t tf_dim = mkl_shape.GetDimension();
1093 const size_t* tf_sizes = mkl_shape.GetSizes();
1094
1095 OP_REQUIRES(context, tf_dim == 4,
1096 errors::InvalidArgument("MKLSizesToTFSizes: size must be 4-dim"));
1097 std::vector<int32> sizes;
1098
1099 sizes.push_back(tf_sizes[3]);
1100
1101 if (data_format_ == FORMAT_NHWC) {
1102 sizes.push_back(tf_sizes[1]);
1103 sizes.push_back(tf_sizes[0]);
1104 sizes.push_back(tf_sizes[2]);
1105 } else {
1106 sizes.push_back(tf_sizes[2]);
1107 sizes.push_back(tf_sizes[1]);
1108 sizes.push_back(tf_sizes[0]);
1109 }
1110
1111 OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(sizes, tf_shape));
1112 }
1113 #endif
1114
GetMklTensorDimIndex(char dimension)1115 inline int32 GetMklTensorDimIndex(char dimension) {
1116 switch (dimension) {
1117 case 'N':
1118 return MklDims::N;
1119 case 'C':
1120 return MklDims::C;
1121 case 'H':
1122 return MklDims::H;
1123 case 'W':
1124 return MklDims::W;
1125 default:
1126 LOG(FATAL) << "Invalid dimension: " << dimension;
1127 return -1; // Avoid compiler warning about missing return value
1128 }
1129 }
1130
1131 #ifdef INTEL_MKL_ML_ONLY
GetMklTensorDim(const MklShape & mkl_shape,char dimension)1132 inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) {
1133 int index = GetMklTensorDimIndex(dimension);
1134 CHECK(index >= 0 && index < mkl_shape.GetDimension())
1135 << "Invalid index from the dimension: " << index << ", " << dimension;
1136 return mkl_shape.dim_size(index);
1137 }
1138 #endif
1139
CopyMklTensorInToOut(OpKernelContext * context,int idx_in,int idx_out)1140 inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
1141 int idx_out) {
1142 int num_inputs = context->num_inputs();
1143 int num_outputs = context->num_outputs();
1144 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1145 int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
1146 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1147 int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
1148
1149 const Tensor& data = context->input(idx_data_in);
1150 const Tensor& meta = context->input(idx_meta_in);
1151 Tensor output(data.dtype());
1152 Tensor meta_output(meta.dtype());
1153
1154 // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
1155 CHECK(output.CopyFrom(data, data.shape()));
1156 CHECK(meta_output.CopyFrom(meta, meta.shape()));
1157 context->set_output(idx_data_out, output);
1158 context->set_output(idx_meta_out, meta_output);
1159 }
1160
1161 #ifdef INTEL_MKL_ML_ONLY
CopyTfTensorInToOutWithShape(OpKernelContext * context,int idx_in,int idx_out,const TensorShape & shape)1162 inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
1163 int idx_out,
1164 const TensorShape& shape) {
1165 int num_inputs = context->num_inputs();
1166 int num_outputs = context->num_outputs();
1167 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1168 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1169
1170 const Tensor& data = context->input(idx_data_in);
1171 MklShape mkl_shape_output;
1172 mkl_shape_output.SetMklTensor(false);
1173 AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
1174 Tensor output(data.dtype());
1175 // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
1176 CHECK(output.CopyFrom(data, shape));
1177 context->set_output(idx_data_out, output);
1178 }
1179 #else
CopyTfTensorInToOutWithShape(OpKernelContext * context,int idx_in,int idx_out,const TensorShape & shape)1180 inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
1181 int idx_out,
1182 const TensorShape& shape) {
1183 int num_inputs = context->num_inputs();
1184 int num_outputs = context->num_outputs();
1185 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1186 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1187
1188 const Tensor& data = context->input(idx_data_in);
1189 MklDnnShape mkl_shape_output;
1190 mkl_shape_output.SetMklTensor(false);
1191 AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
1192 Tensor output(data.dtype());
1193 // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
1194 CHECK(output.CopyFrom(data, shape));
1195 context->set_output(idx_data_out, output);
1196 }
1197 #endif
1198
1199 #ifdef INTEL_MKL_ML_ONLY
1200
ForwardTfTensorInToOut(OpKernelContext * context,int idx_in,int idx_out)1201 inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
1202 int idx_out) {
1203 int num_inputs = context->num_inputs();
1204 int num_outputs = context->num_outputs();
1205 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1206 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1207
1208 MklShape mkl_shape_output;
1209 mkl_shape_output.SetMklTensor(false);
1210 AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
1211 if (IsRefType(context->input_dtype(idx_data_in))) {
1212 context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
1213 } else {
1214 context->set_output(idx_data_out, context->input(idx_data_in));
1215 }
1216 }
1217
1218 #else
1219
ForwardTfTensorInToOut(OpKernelContext * context,int idx_in,int idx_out)1220 inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
1221 int idx_out) {
1222 int num_inputs = context->num_inputs();
1223 int num_outputs = context->num_outputs();
1224 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1225 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1226
1227 MklDnnShape dnn_shape_output;
1228 dnn_shape_output.SetMklTensor(false);
1229 AllocateOutputSetMklShape(context, idx_out, dnn_shape_output);
1230 if (IsRefType(context->input_dtype(idx_data_in))) {
1231 context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
1232 } else {
1233 context->set_output(idx_data_out, context->input(idx_data_in));
1234 }
1235 }
1236
1237 #endif
1238
ForwardMklTensorInToOut(OpKernelContext * context,int idx_in,int idx_out)1239 inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in,
1240 int idx_out) {
1241 int num_inputs = context->num_inputs();
1242 int num_outputs = context->num_outputs();
1243 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1244 int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
1245 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1246 int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
1247
1248 if (IsRefType(context->input_dtype(idx_data_in))) {
1249 context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
1250 context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
1251 } else {
1252 context->set_output(idx_data_out, context->input(idx_data_in));
1253 context->set_output(idx_meta_out, context->input(idx_meta_in));
1254 }
1255 }
1256
1257 #ifndef INTEL_MKL_ML_ONLY
1258 // Set a dummy MKLDNN shape (called when the output is in TF format)
SetDummyMklDnnShapeOutput(OpKernelContext * context,uint32 idx_data_out)1259 inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
1260 uint32 idx_data_out) {
1261 MklDnnShape mkl_shape_output;
1262 mkl_shape_output.SetMklTensor(false);
1263 AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
1264 }
1265
ForwardMklTensorInToOutWithMklShape(OpKernelContext * context,int idx_in,int idx_out,const MklDnnShape & mkl_shape)1266 inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
1267 int idx_in, int idx_out,
1268 const MklDnnShape& mkl_shape) {
1269 int num_inputs = context->num_inputs();
1270 int num_outputs = context->num_outputs();
1271 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
1272 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
1273
1274 AllocateOutputSetMklShape(context, idx_out, mkl_shape);
1275
1276 if (IsRefType(context->input_dtype(idx_data_in))) {
1277 context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
1278 } else {
1279 context->set_output(idx_data_out, context->input(idx_data_in));
1280 }
1281 }
1282 #endif
1283
1284 // Forward the MKL shape ONLY (used in elementwise and other ops where
1285 // we call the eigen implementation and MKL shape is not used)
ForwardMklMetaDataInToOut(OpKernelContext * context,uint32 idx_data_in,uint32_t idx_data_out)1286 inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
1287 uint32 idx_data_in,
1288 uint32_t idx_data_out) {
1289 uint32 idx_meta_in =
1290 GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
1291 uint32 idx_meta_out =
1292 GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
1293
1294 if (IsRefType(context->input_dtype(idx_data_in))) {
1295 context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
1296 } else {
1297 context->set_output(idx_meta_out, context->input(idx_meta_in));
1298 }
1299 }
1300
1301 #ifdef INTEL_MKL_ML_ONLY
1302 // Set a dummy MKL shape (called when the output is in TF format)
SetDummyMklShapeOutput(OpKernelContext * context,uint32 idx_data_out)1303 inline void SetDummyMklShapeOutput(OpKernelContext* context,
1304 uint32 idx_data_out) {
1305 MklShape mkl_shape_output;
1306 mkl_shape_output.SetMklTensor(false);
1307 AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
1308 }
1309 // We don't need these functions in MKLDNN. We have defined equality operator
1310 // on MklDnnShape class directly.
1311
1312 // Checks if the TF shape for both MKL tensors is the same or not
1313 // Returns: true if both TF shapes are the same, false otherwise
MklCompareShapes(const MklShape * input_shape_0,const MklShape * input_shape_1)1314 inline bool MklCompareShapes(const MklShape* input_shape_0,
1315 const MklShape* input_shape_1) {
1316 // Check for number of dimensions
1317 if (input_shape_0->GetDimension() != input_shape_1->GetDimension()) {
1318 return false;
1319 }
1320
1321 // Check size of each dimension
1322 size_t ndims = input_shape_0->GetDimension();
1323 for (size_t i = 0; i < ndims; i++) {
1324 if (input_shape_0->dim_size(i) != input_shape_1->dim_size(i)) {
1325 return false;
1326 }
1327 }
1328
1329 return true;
1330 }
1331
1332 // Checks if the TF shape for both tensors is the same or not
1333 // Returns: true if TF shapes for both are the same, false otherwise
MklCompareShapes(const MklShape * input_shape_0,const TensorShape * input_shape_1)1334 inline bool MklCompareShapes(const MklShape* input_shape_0,
1335 const TensorShape* input_shape_1) {
1336 // Check for number of dimensions
1337 if (input_shape_0->GetDimension() != input_shape_1->dims()) {
1338 return false;
1339 }
1340
1341 // Check size of each dimension
1342 size_t ndims = input_shape_0->GetDimension();
1343 for (size_t i = 0; i < ndims; i++) {
1344 if (input_shape_0->tf_dim_size(i) != input_shape_1->dim_size(i)) {
1345 return false;
1346 }
1347 }
1348
1349 return true;
1350 }
1351
1352 // Checks if the TF shape for both tensors is the same or not
1353 // Returns: true if TF shapes for both are the same, false otherwise
MklCompareShapes(const TensorShape * input_shape_0,const MklShape * input_shape_1)1354 inline bool MklCompareShapes(const TensorShape* input_shape_0,
1355 const MklShape* input_shape_1) {
1356 return MklCompareShapes(input_shape_1, input_shape_0);
1357 }
1358
1359 // Checks if the TF shape for both tensors is the same or not
1360 // Returns: true if TF shapes for both are the same, false otherwise
MklCompareShapes(const TensorShape * input_shape_0,const TensorShape * input_shape_1)1361 inline bool MklCompareShapes(const TensorShape* input_shape_0,
1362 const TensorShape* input_shape_1) {
1363 // Check for number of dimensions
1364 if (input_shape_0->dims() != input_shape_1->dims()) {
1365 return false;
1366 }
1367
1368 // Check size of each dimension
1369 size_t ndims = input_shape_0->dims();
1370 for (size_t i = 0; i < ndims; i++) {
1371 if (input_shape_0->dim_size(i) != input_shape_1->dim_size(i)) {
1372 return false;
1373 }
1374 }
1375
1376 return true;
1377 }
1378
1379 // These functions do not compile with MKL-DNN since mkl.h is missing.
1380 // We may need to remove them later.
1381 // TODO(intel_tf): Remove this routine when faster MKL layout conversion is
1382 // out.
MklNHWCToNCHW(const Tensor & input,Tensor ** output)1383 inline void MklNHWCToNCHW(const Tensor& input, Tensor** output) {
1384 const float* buf_in = input.flat<float>().data();
1385 float* buf_out = (*output)->flat<float>().data();
1386
1387 int64 N = input.dim_size(0);
1388 int64 H = input.dim_size(1);
1389 int64 W = input.dim_size(2);
1390 int64 C = input.dim_size(3);
1391 int64 stride_n = H * W * C;
1392 #pragma omp parallel for num_threads(16)
1393 for (int64 n = 0; n < N; ++n) {
1394 mkl_somatcopy('R', 'T', H * W, C, 1, buf_in + n * stride_n, C,
1395 buf_out + n * stride_n, H * W);
1396 }
1397 }
1398
MklNCHWToNHWC(const Tensor & input,Tensor ** output)1399 inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
1400 const float* buf_in = input.flat<float>().data();
1401 float* buf_out = (*output)->flat<float>().data();
1402
1403 int64 N = (*output)->dim_size(0);
1404 int64 H = (*output)->dim_size(1);
1405 int64 W = (*output)->dim_size(2);
1406 int64 C = (*output)->dim_size(3);
1407 int64 stride_n = H * W * C;
1408 #pragma omp parallel for num_threads(16)
1409 for (int64 n = 0; n < N; ++n) {
1410 mkl_somatcopy('R', 'T', C, H * W, 1, buf_in + n * stride_n, H * W,
1411 buf_out + n * stride_n, C);
1412 }
1413 }
1414
1415 #endif
1416 // -------------------------------------------------------------------
1417
#ifndef INTEL_MKL_ML_ONLY

/// Return MKL-DNN data type (memory::data_type) for input type T
///
/// @input None
/// @return memory::data_type corresponding to type T
template <typename T>
static memory::data_type MklDnnType();

/// Instantiation for float type. Add similar instantiations for other
/// types if needed.
template <>
memory::data_type MklDnnType<float>() {
  return memory::data_type::f32;
}
template <>
memory::data_type MklDnnType<quint8>() {
  return memory::data_type::u8;
}
template <>
memory::data_type MklDnnType<qint8>() {
  return memory::data_type::s8;
}
template <>
memory::data_type MklDnnType<qint32>() {
  return memory::data_type::s32;
}

/// Map TensorFlow's data format into MKL-DNN 3D data format
///
/// @input: TensorFlow data format
/// @return: memory::format corresponding to TensorFlow data format;
///          Fails with an error if invalid data format.
inline memory::format TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
  if (format == FORMAT_NHWC)
    return memory::format::ndhwc;
  else if (format == FORMAT_NCHW)
    return memory::format::ncdhw;
  TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
  return memory::format::format_undef;
}

/// Map TensorFlow's data format into MKL-DNN data format
///
/// @input: TensorFlow data format
/// @return: memory::format corresponding to TensorFlow data format;
///          Fails with an error if invalid data format.
inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
  if (format == FORMAT_NHWC)
    return memory::format::nhwc;
  else if (format == FORMAT_NCHW)
    return memory::format::nchw;
  TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
  return memory::format::format_undef;
}

/// Map MKL-DNN data format to TensorFlow's data format
///
/// @input: memory::format
/// @return: Tensorflow data format corresponding to memory::format
///          Fails with an error if invalid data format.
inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) {
  if (format == memory::format::nhwc || format == memory::format::ndhwc)
    return FORMAT_NHWC;
  else if (format == memory::format::nchw || format == memory::format::ncdhw)
    return FORMAT_NCHW;
  TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));

  // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure
  // that we don't come here.
  return FORMAT_NHWC;
}
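
// Usage sketch (illustrative only; the variables below are hypothetical):
// for the 2D formats the two mappings above are inverses of each other.
//
//   memory::format mkl_fmt = TFDataFormatToMklDnnDataFormat(FORMAT_NHWC);
//   // mkl_fmt == memory::format::nhwc
//   TensorFormat tf_fmt = MklDnnDataFormatToTFDataFormat(mkl_fmt);
//   // tf_fmt == FORMAT_NHWC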

/// Map TensorShape object into memory::dims required by MKL-DNN
///
/// This function will simply map input TensorShape into MKL-DNN dims
/// naively. So it will preserve the order of dimensions. E.g., if
/// input tensor is in NHWC format, then dims will be in NHWC format
/// also.
///
/// @input TensorShape object in shape
/// @return memory::dims corresponding to TensorShape
inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
  memory::dims dims(shape.dims());
  for (int d = 0; d < shape.dims(); ++d) {
    dims[d] = shape.dim_size(d);
  }
  return dims;
}

/// Map TensorShape object into memory::dims in NCHW format required by MKL-DNN
///
/// This function is more specific than the one above. It maps the input
/// TensorShape into MKL-DNN dims in NCHW format, so it may not preserve the
/// order of dimensions. E.g., if the input tensor is in NHWC format, the
/// returned dims will be in NCHW format, not NHWC.
///
/// @input TensorShape object in shape
/// @return memory::dims in MKL-DNN required NCHW format
inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
                                              TensorFormat format) {
  // Check validity of format.
  CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
           memory::format::format_undef);

  int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
  int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
  int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
  int w = shape.dim_size(GetTensorDimIndex(format, 'W'));

  // MKL-DNN requires dimensions in NCHW format.
  return memory::dims({n, c, h, w});
}
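
// Usage sketch (illustrative only; the shape below is hypothetical): an
// NHWC tensor shape is reordered into the NCHW dims MKL-DNN expects, while
// TFShapeToMklDnnDims keeps the original dimension order.
//
//   TensorShape shape({8, 224, 224, 3});  // NHWC
//   memory::dims nchw_dims = TFShapeToMklDnnDimsInNCHW(shape, FORMAT_NHWC);
//   // nchw_dims == {8, 3, 224, 224}
//   memory::dims plain_dims = TFShapeToMklDnnDims(shape);
//   // plain_dims == {8, 224, 224, 3}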

inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
                                               TensorFormat format) {
  // Check validity of format.
  CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
           memory::format::format_undef);

  int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
  int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
  int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
  int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
  int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));

  // MKL-DNN requires dimensions in NCDHW format.
  return memory::dims({n, c, d, h, w});
}

/// Overloaded version of function above. Input parameters are
/// self-explanatory.
inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
                                     TensorFormat format) {
  // Check validity of format.
  CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
           memory::format::format_undef);

  int n = in_dims[GetTensorDimIndex(format, 'N')];
  int c = in_dims[GetTensorDimIndex(format, 'C')];
  int h = in_dims[GetTensorDimIndex(format, 'H')];
  int w = in_dims[GetTensorDimIndex(format, 'W')];

  // MKL-DNN requires dimensions in NCHW format.
  return memory::dims({n, c, h, w});
}

/// Map MklDnn memory::dims object into TensorShape object.
///
/// This function will simply map the input shape in MKL-DNN memory::dims
/// format into TensorFlow's TensorShape object, preserving dimension order.
///
/// @input MKL-DNN memory::dims object
/// @output TensorShape corresponding to memory::dims
inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) {
  std::vector<int32> shape(dims.size(), -1);
  for (int d = 0; d < dims.size(); d++) {
    shape[d] = dims[d];
  }

  TensorShape ret;
  CHECK_EQ(TensorShapeUtils::MakeShape(shape, &ret).ok(), true);
  return ret;
}
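
// Usage sketch (illustrative only; the dims below are hypothetical):
// converting memory::dims back to a TensorShape preserves dimension order.
//
//   memory::dims dims = {8, 224, 224, 3};
//   TensorShape shape = MklDnnDimsToTFShape(dims);  // shape == [8,224,224,3]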

/// Function to calculate strides given tensor shape in Tensorflow order
/// E.g., if dims_tf_order is {1, 2, 3, 4}, then as per Tensorflow convention,
/// dimension with size 1 is outermost dimension; while dimension with size 4 is
/// innermost dimension. So strides for this tensor would be {4 * 3 * 2,
/// 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}.
///
/// @input Tensorflow shape in memory::dims type
/// @return memory::dims containing strides for the tensor.
inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) {
  CHECK_GT(dims_tf_order.size(), 0);
  memory::dims strides(dims_tf_order.size());
  int last_dim_idx = dims_tf_order.size() - 1;
  strides[last_dim_idx] = 1;
  for (int d = last_dim_idx - 1; d >= 0; d--) {
    strides[d] = strides[d + 1] * dims_tf_order[d + 1];
  }
  return strides;
}
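
// Usage sketch (illustrative only), matching the example worked out in the
// comment above:
//
//   memory::dims dims_tf_order = {1, 2, 3, 4};
//   memory::dims strides = CalculateTFStrides(dims_tf_order);
//   // strides == {24, 12, 4, 1}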

inline padding_kind TFPaddingToMklDnnPadding(Padding pad) {
  // MKL-DNN only supports zero padding.
  return padding_kind::zero;
}

/// Helper function to create memory descriptor in Blocked format
///
/// @input: Tensor dimensions
/// @input: strides corresponding to dimensions. One can use utility
///         function such as CalculateTFStrides to compute strides
///         for given dimensions.
/// @return: memory::desc object corresponding to blocked memory format
///          for given dimensions and strides.
inline memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
                                               const memory::dims& strides,
                                               memory::data_type dtype) {
  CHECK_EQ(dim.size(), strides.size());

  // We have to construct memory descriptor in a C style. This is not at all
  // ideal but MKLDNN does not offer any API to construct descriptor in
  // blocked format except a copy constructor that accepts
  // mkldnn_memory_desc_t.
  mkldnn_memory_desc_t md;
  md.primitive_kind = mkldnn_memory;
  md.ndims = dim.size();
  md.format = mkldnn_blocked;
  md.data_type = memory::convert_to_c(dtype);

  for (size_t i = 0; i < dim.size(); i++) {
    md.layout_desc.blocking.block_dims[i] = 1;
    md.layout_desc.blocking.strides[1][i] = 1;
    md.layout_desc.blocking.strides[0][i] = strides[i];
    md.layout_desc.blocking.padding_dims[i] = dim[i];
    md.layout_desc.blocking.offset_padding_to_data[i] = 0;
    md.dims[i] = dim[i];
  }
  md.layout_desc.blocking.offset_padding = 0;

  return memory::desc(md);
}
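
// Usage sketch (illustrative only; the dims are hypothetical): a blocked
// memory descriptor for a row-major 64x32 float buffer.
//
//   memory::dims dims = {64, 32};
//   memory::dims strides = CalculateTFStrides(dims);  // {32, 1}
//   memory::desc md =
//       CreateBlockedMemDescHelper(dims, strides, memory::data_type::f32);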

template <typename T>
inline primitive FindOrCreateReorder(const memory* from, const memory* to);

/*
 * Class to represent all the resources corresponding to a tensor in TensorFlow
 * that are required to execute an operation (such as Convolution).
 */
template <typename T>
class MklDnnData {
 private:
  /// MKL-DNN memory primitive for input user memory
  memory* user_memory_;

  /// MKL-DNN memory primitive in case input or output reorder is needed.
  memory* reorder_memory_;

  /// Operations memory descriptor
  memory::desc* op_md_;
  // Flag to indicate whether the data is 3D or not.
  bool bIs3D;
  /// Operations temp buffer
  void* allocated_buffer_;
  /// CPU engine on which operation will be executed
  const engine* cpu_engine_;

 public:
  explicit MklDnnData(const engine* e)
      : user_memory_(nullptr),
        reorder_memory_(nullptr),
        op_md_(nullptr),
        allocated_buffer_(nullptr),
        cpu_engine_(e) {}

  ~MklDnnData() {
    if (allocated_buffer_ != nullptr) {
      cpu_allocator()->DeallocateRaw(allocated_buffer_);
    }
    cpu_engine_ = nullptr;  // We don't own this.
    delete (user_memory_);
    delete (reorder_memory_);
    delete (op_md_);
  }

  inline void* GetTensorBuffer(const Tensor* tensor) const {
    CHECK_NOTNULL(tensor);
    return const_cast<void*>(
        static_cast<const void*>(tensor->flat<T>().data()));
  }

  void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }

  bool GetIs3D() { return bIs3D; }

  /// Set user memory primitive using specified dimensions, memory format and
  /// data_buffer. The function automatically uses the element data type from
  /// the input type T used to create the calling object.
  ///
  /// In a nutshell, the function allows the user to describe the input tensor
  /// of an operation. E.g., the filter of Conv2D has shape {1, 2, 3, 4} and
  /// memory format HWIO, and the buffer that contains the actual values is
  /// pointed to by data_buffer.
  inline void SetUsrMem(const memory::dims& dim, memory::format fm,
                        void* data_buffer = nullptr) {
    auto md = memory::desc(dim, MklDnnType<T>(), fm);
    SetUsrMem(md, data_buffer);
  }

  inline void SetUsrMem(const memory::dims& dim, memory::format fm,
                        const Tensor* tensor) {
    CHECK_NOTNULL(tensor);
    SetUsrMem(dim, fm, GetTensorBuffer(tensor));
  }

  /// Helper function to create memory descriptor in Blocked format
  ///
  /// @input: Tensor dimensions
  /// @input: strides corresponding to dimensions. One can use utility
  ///         function such as CalculateTFStrides to compute strides
  ///         for given dimensions.
  /// @return: memory::desc object corresponding to blocked memory format
  ///          for given dimensions and strides.
  static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim,
                                                  const memory::dims& strides) {
    return CreateBlockedMemDescHelper(dim, strides, MklDnnType<T>());
  }

  /// A version of SetUsrMem call that allows user to create memory in blocked
  /// format. So in addition to accepting dimensions, it also accepts strides.
  /// This allows user to create memory for tensor in a format that is not
  /// supported by MKLDNN. E.g., MKLDNN does not support tensor format for 6
  /// dimensional tensor as a native format. But by using blocked format, a
  /// user can create memory for a 6D tensor.
  inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
                        void* data_buffer = nullptr) {
    CHECK_EQ(dim.size(), strides.size());
    auto blocked_md = MklDnnData<T>::CreateBlockedMemDesc(dim, strides);
    SetUsrMem(blocked_md, data_buffer);
  }

  inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
                        const Tensor* tensor) {
    CHECK_NOTNULL(tensor);
    SetUsrMem(dim, strides, GetTensorBuffer(tensor));
  }

  /// A version of function to set user memory primitive that accepts memory
  /// descriptor directly, instead of accepting dimensions and format. This
  /// function is more generic than the one above, but the function above is
  /// sufficient in most cases.
  inline void SetUsrMem(const memory::desc& md, void* data_buffer = nullptr) {
    auto pd = memory::primitive_desc(md, *cpu_engine_);
    SetUsrMem(pd, data_buffer);
  }

  /// A version of SetUsrMem with memory descriptor and tensor
  inline void SetUsrMem(const memory::desc& md, const Tensor* tensor) {
    CHECK_NOTNULL(tensor);
    SetUsrMem(md, GetTensorBuffer(tensor));
  }

  /// A version of function to set user memory primitive that accepts primitive
  /// descriptor directly, instead of accepting dimensions and format. This
  /// function is more generic than the one above, but the function above is
  /// sufficient in most cases.
  inline void SetUsrMem(const memory::primitive_desc& pd,
                        void* data_buffer = nullptr) {
    CHECK_NOTNULL(cpu_engine_);
    if (user_memory_) delete user_memory_;
    // TODO(nhasabni): can we remove dynamic memory allocation?
    if (data_buffer) {
      user_memory_ = new memory(pd, data_buffer);
    } else {
      user_memory_ = new memory(pd);
    }
  }

  /// A version of SetUsrMem with primitive descriptor and tensor
  inline void SetUsrMem(const memory::primitive_desc& pd,
                        const Tensor* tensor) {
    CHECK_NOTNULL(tensor);
    SetUsrMem(pd, GetTensorBuffer(tensor));
  }

  /// Get function for user memory primitive.
  inline const memory* GetUsrMem() const { return user_memory_; }

  /// Get function for primitive descriptor of user memory primitive.
  inline const memory::primitive_desc GetUsrMemPrimDesc() const {
    CHECK_NOTNULL(user_memory_);
    return user_memory_->get_primitive_desc();
  }

  /// Get function for descriptor of user memory.
  inline memory::desc GetUsrMemDesc() {
    // This is ugly. Why does MKL-DNN not provide a const desc() method?
    const memory::primitive_desc pd = GetUsrMemPrimDesc();
    return const_cast<memory::primitive_desc*>(&pd)->desc();
  }

  /// Get function for data buffer of user memory primitive.
  inline void* GetUsrMemDataHandle() const {
    CHECK_NOTNULL(user_memory_);
    return user_memory_->get_data_handle();
  }

  /// Set function for data buffer of user memory primitive.
  inline void SetUsrMemDataHandle(void* data_buffer) {
    CHECK_NOTNULL(user_memory_);
    CHECK_NOTNULL(data_buffer);
    user_memory_->set_data_handle(data_buffer);
  }

  /// Set function for data buffer of user memory primitive.
  inline void SetUsrMemDataHandle(const Tensor* tensor) {
    CHECK_NOTNULL(user_memory_);
    CHECK_NOTNULL(tensor);
    user_memory_->set_data_handle(GetTensorBuffer(tensor));
  }

  /// Allocate function for data buffer.
  inline void AllocateBuffer(size_t size) {
    const int64 kMemoryAlginment = 64;  // For AVX512 memory alignment.
    allocated_buffer_ = cpu_allocator()->AllocateRaw(kMemoryAlginment, size);
  }

  inline void* GetAllocatedBuffer() { return allocated_buffer_; }

  /// Get the memory primitive for input and output of an op. If inputs
  /// to an op require reorders, then this function returns memory primitive
  /// for reorder. Otherwise, it will return memory primitive for user memory.
  ///
  /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
  /// execute Conv2D, we need memory primitive for I and F. But if reorder is
  /// required for I and F (say I_r is reorder primitive for I; F_r is reorder
  /// primitive for F), then we need I_r and F_r to perform Conv2D.
  inline const memory& GetOpMem() const {
    return reorder_memory_ ? *reorder_memory_ : *user_memory_;
  }

  /// Set memory descriptor of an operation in terms of dimensions and memory
  /// format. E.g., For Conv2D, the dimensions would be same as user dimensions
  /// but memory::format would be mkldnn::any because we want MKL-DNN to choose
  /// the best layout/format for given input dimensions.
  inline void SetOpMemDesc(const memory::dims& dim, memory::format fm) {
    // TODO(nhasabni): can we remove dynamic memory allocation?
    op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
  }

  /// Get function for memory descriptor for an operation
  inline const memory::desc& GetOpMemDesc() const { return *op_md_; }

  /// Predicate that checks if we need to reorder user's memory into memory
  /// pointed by op_pd.
  ///
  /// @input: op_pd - memory primitive descriptor of the given input of an
  ///         operation
  /// @return: true in case reorder of input is needed; false, otherwise.
  inline bool IsReorderNeeded(const memory::primitive_desc& op_pd) const {
    CHECK_NOTNULL(user_memory_);
    return op_pd != user_memory_->get_primitive_desc();
  }

  /// Predicate that checks if we need to reorder user's memory into memory
  /// based on the provided format.
  ///
  /// @input: target_format - memory format of the given input of an
  ///         operation
  /// @return: true in case reorder of input is needed; false, otherwise.
  inline bool IsReorderNeeded(const memory::format& target_format) const {
    CHECK_NOTNULL(user_memory_);
    return target_format !=
           user_memory_->get_primitive_desc().desc().data.format;
  }

  /// Function to create a reorder from memory pointed by from to memory
  /// pointed by to. Returns created primitive.
  inline primitive CreateReorder(const memory* from, const memory* to) const {
    CHECK_NOTNULL(from);
    CHECK_NOTNULL(to);
    return reorder(*from, *to);
  }

  /// Function to handle input reordering
  ///
  /// Check if we need to reorder this input of an operation.
  /// Return true and allocate reorder memory primitive if reorder is needed.
  /// Otherwise, return false and do not allocate reorder memory primitive.
  ///
  /// To check if reorder is needed, this function compares memory primitive
  /// descriptor of an operation (op_pd) for the given input with the
  /// user-specified memory primitive descriptor.
  ///
  /// @input: op_pd - memory primitive descriptor of the given input of an
  ///         operation
  /// @input: net - net to which to add reorder primitive in case it is needed.
  /// @return: true in case reorder of input is needed; false, otherwise.
  inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
                                  std::vector<primitive>* net) {
    CHECK_NOTNULL(net);
    CHECK_NOTNULL(user_memory_);
    if (IsReorderNeeded(op_pd)) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      reorder_memory_ = new memory(op_pd);
      net->push_back(CreateReorder(user_memory_, reorder_memory_));
      return true;
    }
    return false;
  }

  /// TODO: this is a faster path with reorder primitive cache compared with
  /// CheckReorderToOpMem(..., std::vector<primitive>* net); the slow path
  /// will be removed in the future.
  inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd) {
    CHECK_NOTNULL(user_memory_);
    if (IsReorderNeeded(op_pd)) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      // Primitive reuse does not allow two identical reorder primitives in
      // one stream, so submit it immediately.
      reorder_memory_ = new memory(op_pd);
      std::vector<primitive> net;
      net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_));
      stream(stream::kind::eager).submit(net).wait();
      return true;
    }
    return false;
  }

  /// Overloaded version of above function that accepts memory buffer
  /// where output of reorder needs to be stored.
  ///
  /// @input: op_pd - memory primitive descriptor of the given input of an
  ///         operation
  /// @reorder_data_handle - memory buffer where output of reorder needs to be
  ///                        stored. The primitive does not check whether the
  ///                        buffer is of sufficient size to write into.
  /// @input: net - net to which to add reorder primitive in case it is needed.
  /// @return: true in case reorder of input is needed; false, otherwise.
  inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
                                  void* reorder_data_handle,
                                  std::vector<primitive>* net) {
    CHECK_NOTNULL(net);
    CHECK_NOTNULL(reorder_data_handle);
    CHECK_NOTNULL(user_memory_);
    if (IsReorderNeeded(op_pd)) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      reorder_memory_ = new memory(op_pd, reorder_data_handle);
      net->push_back(CreateReorder(user_memory_, reorder_memory_));
      return true;
    }
    return false;
  }

  /// TODO: this is a faster path with reorder primitive cache compared with
  /// CheckReorderToOpMem(..., std::vector<primitive>* net); the slow path
  /// will be removed in the future.
  inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
                                  void* reorder_data_handle) {
    CHECK_NOTNULL(reorder_data_handle);
    CHECK_NOTNULL(user_memory_);
    if (IsReorderNeeded(op_pd)) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      // Primitive reuse does not allow two identical reorder primitives in
      // one stream, so submit it immediately.
      std::vector<primitive> net;
      reorder_memory_ = new memory(op_pd, reorder_data_handle);
      net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_));
      stream(stream::kind::eager).submit(net).wait();
      return true;
    }
    return false;
  }

  /// Another overloaded version of CheckReorderToOpMem that accepts Tensor
  /// where output of reorder needs to be stored.
  ///
  /// @input: op_pd - memory primitive descriptor of the given input of an
  ///         operation
  /// @reorder_tensor - Tensor whose buffer is to be used to store output of
  ///                   reorder. The primitive does not check whether the
  ///                   buffer is of sufficient size to write into.
  /// @input: net - net to which to add reorder primitive in case it is needed.
  /// @return: true in case reorder of input is needed; false, otherwise.
  inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
                                  Tensor* reorder_tensor,
                                  std::vector<primitive>* net) {
    CHECK_NOTNULL(net);
    CHECK_NOTNULL(reorder_tensor);
    return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), net);
  }

  /// TODO: this is a faster path with reorder primitive cache compared with
  /// CheckReorderToOpMem(..., std::vector<primitive>* net); the slow path
  /// will be removed in the future.
  inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
                                  Tensor* reorder_tensor) {
    CHECK_NOTNULL(reorder_tensor);
    return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor));
  }

  /// Function to handle output reorder
  ///
  /// This function performs very similar functionality as the input
  /// reordering function above. The only difference is that this function
  /// does not add the reorder primitive to the net. The reason for this is:
  /// the reorder primitive for output needs to be added to the list only
  /// after the operation has executed. But we need to prepare a temporary
  /// buffer in case output reorder is needed. This temporary buffer holds
  /// the output of the operation before it is fed to the reorder primitive.
  ///
  /// @input memory primitive descriptor for the given output of an operation
  /// @return: true in case reorder of output is needed; false, otherwise.
  inline bool PrepareReorderToUserMemIfReq(
      const memory::primitive_desc& op_pd) {
    CHECK_NOTNULL(user_memory_);
    if (IsReorderNeeded(op_pd)) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      reorder_memory_ = new memory(op_pd);
      return true;
    }
    return false;
  }

  /// Function to actually insert reorder primitive in the net
  ///
  /// This function completes the remaining part of output reordering. It
  /// inserts a reordering primitive from the temporary buffer that holds the
  /// output to the user-specified output buffer.
  ///
  /// @input: net - net to which to add reorder primitive
  inline void InsertReorderToUserMem(std::vector<primitive>* net) {
    CHECK_NOTNULL(net);
    CHECK_NOTNULL(user_memory_);
    CHECK_NOTNULL(reorder_memory_);
    net->push_back(CreateReorder(reorder_memory_, user_memory_));
  }

  /// TODO: this is a faster path with reorder primitive cache compared with
  /// InsertReorderToUserMem(std::vector<primitive>* net); the slow path
  /// will be removed in the future.
  inline void InsertReorderToUserMem() {
    CHECK_NOTNULL(user_memory_);
    CHECK_NOTNULL(reorder_memory_);
    // Primitive reuse does not allow two identical reorder primitives in
    // one stream, so submit it immediately.
    std::vector<primitive> net;
    net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
    stream(stream::kind::eager).submit(net).wait();
  }
};
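
// Usage sketch for MklDnnData (illustrative only; cpu_engine, input_tensor
// and op_pd are hypothetical stand-ins for an op's actual engine, input
// tensor and preferred memory primitive descriptor):
//
//   MklDnnData<float> src(&cpu_engine);
//   src.SetUsrMem(memory::dims({8, 3, 224, 224}), memory::format::nchw,
//                 &input_tensor);
//   std::vector<primitive> net;
//   // Adds a reorder to net only if the user layout differs from op_pd.
//   src.CheckReorderToOpMem(op_pd, &net);
//   // ... append the compute primitive that consumes src.GetOpMem() ...
//   stream(stream::kind::eager).submit(net).wait();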

/// Base class for operations with reuse of primitives
///
class MklPrimitive {
 public:
  virtual ~MklPrimitive() {}

  // Dummy data which MKL DNN never operates on
  unsigned char* DummyData = nullptr;
};

const mkldnn::memory::dims NONE_DIMS = {};

//
// LRUCache is a class which implements LRU (Least Recently Used) cache.
// The implementation is similar to that of
// tensorflow/core/platform/cloud/expiring_lru_cache.h
// without its thread-safe part because the cache is supposed to be
// used as thread local (for instance, MklPrimitive caching).
//
// The LRU list maintains objects in chronological order based on
// creation time, with the least recently accessed object at the
// tail of the LRU list, while the most recently accessed object
// is at the head of the LRU list.
//
// This class is used to maintain an upper bound on the total number of
// cached items. When the cache reaches its capacity, the LRU item will
// be removed and replaced by a new one from a SetOp call.
//
template <typename T>
class LRUCache {
 public:
  explicit LRUCache(size_t capacity) {
    capacity_ = capacity;
    Clear();
  }

  T* GetOp(const string& key) {
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      return nullptr;
    }

    // Move to the front of LRU list as the most recently accessed.
    lru_list_.erase(it->second.lru_iterator);
    lru_list_.push_front(it->first);
    it->second.lru_iterator = lru_list_.begin();
    return it->second.op;
  }

  void SetOp(const string& key, T* op) {
    if (lru_list_.size() >= capacity_) {
      Delete();
    }

    // Insert an entry to the front of the LRU list
    lru_list_.push_front(key);
    Entry entry(op, lru_list_.begin());
    cache_.emplace(std::make_pair(key, std::move(entry)));
  }

  void Clear() {
    if (lru_list_.empty()) return;

    // Clean up the cache
    cache_.clear();
    lru_list_.clear();
  }

 private:
  struct Entry {
    // The entry's value.
    T* op;

    // A list iterator pointing to the entry's position in the LRU list.
    std::list<string>::iterator lru_iterator;

    // Constructor
    Entry(T* op, std::list<string>::iterator it) {
      this->op = op;
      this->lru_iterator = it;
    }

    // Move constructor
    Entry(Entry&& source) noexcept
        : lru_iterator(std::move(source.lru_iterator)) {
      op = std::move(source.op);
      source.op = std::forward<T*>(nullptr);
    }

    // Destructor
    ~Entry() {
      if (op != nullptr) delete op;
    }
  };

  // Remove the least recently accessed entry from the LRU list, which
  // is the tail of lru_list_. Update cache_ correspondingly.
  bool Delete() {
    if (lru_list_.empty()) return false;
    string key = lru_list_.back();
    lru_list_.pop_back();
    cache_.erase(key);
    return true;
  }

  // Cache capacity
  size_t capacity_;

  // The cache, a map from string key to an LRU entry.
  std::unordered_map<string, Entry> cache_;

  // The LRU list of entries.
  // The front of the list contains the key of the most recently accessed
  // entry, while the back of the list is the least recently accessed entry.
  std::list<string> lru_list_;
};
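
// Usage sketch for LRUCache (illustrative only; keys and capacity are
// hypothetical). The cache owns the stored pointers and deletes them when
// an entry is evicted.
//
//   LRUCache<MklPrimitive> cache(2);
//   cache.SetOp("conv_fwd_a", new MklPrimitive());
//   cache.SetOp("conv_fwd_b", new MklPrimitive());
//   MklPrimitive* hit = cache.GetOp("conv_fwd_a");  // non-null; now MRU
//   cache.SetOp("conv_fwd_c", new MklPrimitive());  // evicts "conv_fwd_b"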

template <typename T>
class MklPrimitiveFactory {
 public:
  MklPrimitiveFactory() {}

  ~MklPrimitiveFactory() {}

  MklPrimitive* GetOp(const string& key) {
    auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
    return lru_cache.GetOp(key);
  }

  void SetOp(const string& key, MklPrimitive* op) {
    auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
    lru_cache.SetOp(key, op);
  }

  /// Function to decide whether HW has AVX512 or AVX2
  /// For legacy devices (without AVX512 and AVX2),
  /// MKL-DNN GEMM will be used.
  static inline bool IsLegacyPlatform() {
    return (!port::TestCPUFeature(port::CPUFeature::AVX512F) &&
            !port::TestCPUFeature(port::CPUFeature::AVX2));
  }

  /// Function to check whether primitive memory optimization is enabled
  static inline bool IsPrimitiveMemOptEnabled() {
    bool is_primitive_mem_opt_enabled = true;
    TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
                                   &is_primitive_mem_opt_enabled));
    return is_primitive_mem_opt_enabled;
  }

 private:
  static inline LRUCache<MklPrimitive>& GetLRUCache() {
    static const int kCapacity = 1024;  // cache capacity
    static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity);
    return lru_cache_;
  }
};

// Utility class for creating keys for the MKL primitive pool.
class FactoryKeyCreator {
 public:
  FactoryKeyCreator() { key_.reserve(kMaxKeyLength); }

  ~FactoryKeyCreator() {}

  void AddAsKey(const string& str) { Append(str); }

  void AddAsKey(const mkldnn::memory::dims& dims) {
    for (unsigned int i = 0; i < dims.size(); i++) {
      AddAsKey<int>(dims[i]);
    }
  }

  template <typename T>
  void AddAsKey(const T data) {
    auto buffer = reinterpret_cast<const char*>(&data);
    Append(StringPiece(buffer, sizeof(T)));
  }

  string GetKey() { return key_; }

 private:
  string key_;
  const char delimiter = 'x';
  const int kMaxKeyLength = 256;
  void Append(StringPiece s) {
    key_.append(string(s));
    key_.append(1, delimiter);
  }
};
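
// Usage sketch for FactoryKeyCreator (illustrative only; the values are
// hypothetical). Each component is appended in raw byte form followed by
// the 'x' delimiter, so equal inputs always produce the same key.
//
//   FactoryKeyCreator key_creator;
//   key_creator.AddAsKey(string("conv2d_fwd"));
//   key_creator.AddAsKey(memory::dims({8, 3, 224, 224}));
//   key_creator.AddAsKey<int>(static_cast<int>(memory::format::nchw));
//   string key = key_creator.GetKey();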

static inline memory::format get_desired_format(int channel,
                                                bool is_2d = true) {
  memory::format fmt_desired = memory::format::any;

  if (port::TestCPUFeature(port::CPUFeature::AVX512F)) {
    fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c;
  } else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
             (channel % 8) == 0) {
    fmt_desired = is_2d ? memory::format::nChw8c
                        : memory::format::ncdhw;  // no avx2 support for 3d yet.
  } else {
    fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw;
  }
  return fmt_desired;
}
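
// Usage sketch (illustrative only; the channel count is hypothetical):
// pick the blocked layout the current CPU prefers for a 2D tensor.
//
//   memory::format fmt = get_desired_format(64, /*is_2d=*/true);
//   // nChw16c with AVX512F, nChw8c with AVX2 (channel % 8 == 0), else nchw.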

class MklReorderPrimitive : public MklPrimitive {
 public:
  explicit MklReorderPrimitive(const memory* from, const memory* to) {
    Setup(from, to);
  }
  ~MklReorderPrimitive() {}

  std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }

  void SetMemory(const memory* from, const memory* to) {
    context_.src_mem->set_data_handle(from->get_data_handle());
    context_.dst_mem->set_data_handle(to->get_data_handle());
  }

 private:
  struct ReorderContext {
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;
    std::shared_ptr<primitive> reorder_prim;
    ReorderContext()
        : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
  } context_;

  engine cpu_engine_ = engine(engine::cpu, 0);

  void Setup(const memory* from, const memory* to) {
    context_.src_mem.reset(new memory(
        {from->get_primitive_desc().desc(), cpu_engine_}, DummyData));
    context_.dst_mem.reset(
        new memory({to->get_primitive_desc().desc(), cpu_engine_}, DummyData));
    context_.reorder_prim = std::make_shared<mkldnn::reorder>(
        reorder(*context_.src_mem, *context_.dst_mem));
  }
};

template <typename T>
class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklReorderPrimitive* Get(const memory* from, const memory* to) {
    auto reorderPrim = static_cast<MklReorderPrimitive*>(
        MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
    if (reorderPrim == nullptr) {
      reorderPrim = new MklReorderPrimitive(from, to);
      MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to,
                                                              reorderPrim);
    }
    reorderPrim->SetMemory(from, to);
    return reorderPrim;
  }

  static MklReorderPrimitiveFactory& GetInstance() {
    static MklReorderPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklReorderPrimitiveFactory() {}
  ~MklReorderPrimitiveFactory() {}

  static string CreateKey(const memory* from, const memory* to) {
    string prefix = "reorder";
    FactoryKeyCreator key_creator;
    auto const& from_desc = from->get_primitive_desc().desc().data;
    auto const& to_desc = to->get_primitive_desc().desc().data;
    const int KIdxFirstStride = 0;
    memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
    memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
    memory::dims from_strides(
        from_desc.layout_desc.blocking.strides[KIdxFirstStride],
        &from_desc.layout_desc.blocking
             .strides[KIdxFirstStride][from_desc.ndims]);
    memory::dims to_strides(
        to_desc.layout_desc.blocking.strides[KIdxFirstStride],
        &to_desc.layout_desc.blocking.strides[KIdxFirstStride][to_desc.ndims]);
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(static_cast<int>(from_desc.format));
    key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
    key_creator.AddAsKey(from_dims);
    key_creator.AddAsKey(from_strides);
    key_creator.AddAsKey(static_cast<int>(to_desc.format));
    key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
    key_creator.AddAsKey(to_dims);
    key_creator.AddAsKey(to_strides);
    return key_creator.GetKey();
  }

  MklPrimitive* GetReorder(const memory* from, const memory* to) {
    string key = CreateKey(from, to);
    return this->GetOp(key);
  }

  void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
    string key = CreateKey(from, to);
    this->SetOp(key, op);
  }
};

/// Function to find (or create) a reorder from the memory pointed to by
/// `from` to the memory pointed to by `to`. It creates the primitive, or
/// gets it from the pool if it is already cached.
/// Returns the primitive.
template <typename T>
inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
  CHECK_NOTNULL(from);
  CHECK_NOTNULL(to);
  MklReorderPrimitive* reorder_prim =
      MklReorderPrimitiveFactory<T>::Get(from, to);
  return *reorder_prim->GetPrimitive();
}
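
// Usage sketch for FindOrCreateReorder (illustrative only; user_mem and
// op_mem are hypothetical memory primitives with different layouts). This
// mirrors how the cached-reorder path of MklDnnData uses it.
//
//   std::vector<primitive> net;
//   net.push_back(FindOrCreateReorder<float>(&user_mem, &op_mem));
//   stream(stream::kind::eager).submit(net).wait();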

// Utility function to determine whether a convolution is 1x1 with
// stride != 1, for the purpose of temporarily disabling primitive reuse.
inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
                                memory::dims strides) {
  if (filter_dims.size() != 4 || strides.size() != 2) return false;

  return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
          ((strides[0] != 1) || (strides[1] != 1)));
}

#endif  // !INTEL_MKL_ML_ONLY

}  // namespace tensorflow
#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_