/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
#define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
#ifdef INTEL_MKL

#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "mkldnn.hpp"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/mkl_threadpool.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

using mkldnn::engine;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using CPUDevice = Eigen::ThreadPoolDevice;
using MemoryArgsMap = std::unordered_map<int, memory>;
using ReorderPd = mkldnn::reorder::primitive_desc;

#ifdef _WIN32
typedef unsigned int uint;
#endif

namespace tensorflow {
// This file contains a number of utility classes and functions used by
// MKL-enabled kernels.

// The MklDnnShape class (defined below) encapsulates all the metadata that is
// associated with an MKL tensor. A tensor is an MKL tensor if it was created
// as the result of an MKL operation, and did not go through a conversion to a
// standard TensorFlow tensor.

// The dimension order that MKL-DNN uses internally for 2D activations
// [Batch, Channel, Height, Width] and
// for 2D filters [Out_Channel, In_Channel, Height, Width].
typedef enum {
  Dim_N = 0,
  Dim_C = 1,
  Dim_H = 2,
  Dim_W = 3,
  Dim_O = 0,
  Dim_I = 1
} MklDnnDims;

// The dimension order that MKL-DNN uses internally for 3D activations
// [Batch, Channel, Depth, Height, Width] and
// for 3D filters [Out_Channel, In_Channel, Depth, Height, Width].
typedef enum {
  Dim3d_N = 0,
  Dim3d_C = 1,
  Dim3d_D = 2,
  Dim3d_H = 3,
  Dim3d_W = 4,
  Dim3d_O = 0,
  Dim3d_I = 1
} MklDnnDims3D;

// Enum for the order of dimensions of a TF 2D filter with shape
// [filter_height, filter_width, in_channels, out_channels].
typedef enum {
  TF_2DFILTER_DIM_H = 0,
  TF_2DFILTER_DIM_W = 1,
  TF_2DFILTER_DIM_I = 2,
  TF_2DFILTER_DIM_O = 3
} TFFilterDims2d;

// Enum for the order of dimensions of a TF 3D filter with shape
// [filter_depth, filter_height, filter_width, in_channels, out_channels].
typedef enum {
  TF_3DFILTER_DIM_P = 0,
  TF_3DFILTER_DIM_H = 1,
  TF_3DFILTER_DIM_W = 2,
  TF_3DFILTER_DIM_I = 3,
  TF_3DFILTER_DIM_O = 4
} TFFilterDims3d;

// The dimension order that MKL-DNN requires for the filter in a grouped
// convolution (2D only).
typedef enum {
  MKL_GROUP_FILTER_DIM_G = 0,
  MKL_GROUP_FILTER_DIM_O = 1,
  MKL_GROUP_FILTER_DIM_I = 2,
  MKL_GROUP_FILTER_DIM_H = 3,
  MKL_GROUP_FILTER_DIM_W = 4
} MklDnnFilterGroupDims;

// Enum used to templatize MklOp kernel implementations
// that support both fp32 and int8 versions.
enum class MklQuantization {
  QUANTIZED_VERSION,
  FP_VERSION,
};

static const int kSmallBatchSize = 32;
inline void execute_primitives(
    std::vector<mkldnn::primitive>& primitives, std::shared_ptr<stream> stream,
    std::vector<std::unordered_map<int, memory>>& net_args) {
  DCHECK_EQ(primitives.size(), net_args.size());
  for (size_t i = 0; i < primitives.size(); ++i) {
    primitives.at(i).execute(*stream, net_args.at(i));
  }
}
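
// Illustrative usage sketch (not part of the original header; `cpu_engine`,
// `src_mem`, and `dst_mem` are hypothetical placeholders). Each primitive in
// `primitives` is paired with the argument map stored at the same index in
// `net_args`:
//
//   std::vector<mkldnn::primitive> primitives;
//   std::vector<std::unordered_map<int, memory>> net_args;
//   primitives.push_back(reorder(src_mem, dst_mem));
//   net_args.push_back({{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
//   auto cpu_stream = std::make_shared<stream>(cpu_engine);
//   execute_primitives(primitives, cpu_stream, net_args);
//   cpu_stream->wait();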

// In MKL-DNN v1.x, the format (ex. NCHW) used to initialize a memory descriptor
// (md) structure will no longer be recorded in its `format` field. Instead, it
// will be set to a canonical `blocked` format for every fully described md.
//
// Currently, we query this `format` field while mapping MKL-DNN's data format
// to TF's data format. Due to the above restriction, we will now get this data
// format information from TF's `data_format` attribute (i.e. via
// `TensorFormat`) for MKL-DNN v1.x.
//
// Some MKL-DNN operators such as ReLU do not have a `data_format` attribute
// since they are usually in `blocked` format. Therefore, in order to
// distinguish between blocked and non-blocked formats, we have defined a new
// enum called `MklTensorFormat` that is semantically similar to `TensorFormat`
// but adds the following values:
// 1) FORMAT_BLOCKED: as described above, this is needed for element-wise
//    operators such as ReLU.
// 2) FORMAT_INVALID: for error-checking (ex. unsupported format)
// 3) FORMAT_X, FORMAT_NC, FORMAT_TNC: to distinguish between MKL tensors based
//    on their dimensions in operators such as Softmax, i.e.:
//    FORMAT_X - 1D tensor
//    FORMAT_NC - 2D tensor
//    FORMAT_TNC - 3D tensor
enum class MklTensorFormat {
  FORMAT_NHWC = 0,
  FORMAT_NCHW = 1,
  FORMAT_NDHWC = 2,
  FORMAT_NCDHW = 3,
  FORMAT_X = 4,
  FORMAT_NC = 5,
  FORMAT_TNC = 6,
  FORMAT_BLOCKED = 7,
  FORMAT_INVALID = 8,
};

// Forward declarations
memory::format_tag MklTensorFormatToMklDnnDataFormat(MklTensorFormat format);

TensorFormat MklDnn3DDataFormatToTFDataFormat(MklTensorFormat format);
TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format);

memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
Status CreateBlockedMemDescHelper(const memory::dims& dim,
                                  const memory::dims& strides,
                                  memory::data_type dtype,
                                  mkldnn_memory_desc_t* blocked_md);

inline std::ostream& operator<<(std::ostream& os,
                                const memory::format_tag& tag) {
  if (tag == memory::format_tag::undef) {
    os << "undef";
  } else if (tag == memory::format_tag::any) {
    os << "any";
  } else {
    os << "invalid";
  }
  return os;
}

inline void operator<<(std::ostream& os, const MklTensorFormat& format) {
  if (format == MklTensorFormat::FORMAT_NHWC) {
    os << "FORMAT_NHWC";
  } else if (format == MklTensorFormat::FORMAT_NCHW) {
    os << "FORMAT_NCHW";
  } else if (format == MklTensorFormat::FORMAT_NDHWC) {
    os << "FORMAT_NDHWC";
  } else if (format == MklTensorFormat::FORMAT_NCDHW) {
    os << "FORMAT_NCDHW";
  } else if (format == MklTensorFormat::FORMAT_X) {
    os << "FORMAT_X";
  } else if (format == MklTensorFormat::FORMAT_NC) {
    os << "FORMAT_NC";
  } else if (format == MklTensorFormat::FORMAT_TNC) {
    os << "FORMAT_TNC";
  } else if (format == MklTensorFormat::FORMAT_BLOCKED) {
    os << "FORMAT_BLOCKED";
  } else {
    os << "INVALID FORMAT";
  }
}

template <typename T>
inline bool array_cmp(const T* a1, const T* a2, size_t size) {
  for (size_t i = 0; i < size; ++i)
    if (a1[i] != a2[i]) return false;
  return true;
}

inline mkldnn::stream* CreateStream(MklDnnThreadPool* eigen_tp,
                                    const engine& engine) {
#ifndef ENABLE_ONEDNN_OPENMP
  if (eigen_tp != nullptr) {
    stream* tp_stream =
        new stream(dnnl::threadpool_interop::make_stream(engine, eigen_tp));
    return tp_stream;
  } else {
    stream* tp_stream = new stream(engine);
    return tp_stream;
  }
#else
  stream* tp_stream = new stream(engine);
  return tp_stream;
#endif  // !ENABLE_ONEDNN_OPENMP
}
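
// Illustrative usage sketch (an assumption, not from the original header):
// the caller owns the returned stream, so it is typically wrapped in a smart
// pointer; `context` is a hypothetical OpKernelContext*.
//
//   engine cpu_engine(engine::kind::cpu, 0);
//   MklDnnThreadPool eigen_tp(context);
//   std::unique_ptr<stream> cpu_stream(CreateStream(&eigen_tp, cpu_engine));
//   // ... submit primitives on *cpu_stream ...
//   cpu_stream->wait();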

class MklDnnShape {
 private:
  struct MklShapeData {
    // Flag to indicate if the tensor is an MKL tensor or not
    bool is_mkl_tensor_ = false;
    // Number of dimensions in Tensorflow format
    size_t dimension_ = 0;
    mkldnn_dims_t sizes_;  // Required by MKL for conversions
    MklTensorFormat tf_data_format_ = MklTensorFormat::FORMAT_BLOCKED;
    memory::data_type T_ = memory::data_type::undef;
    // MKL layout
    mkldnn_memory_desc_t mkl_md_;
    /// TF dimension corresponding to this MKL dimension
    mkldnn_dims_t map_;
  };
  MklShapeData data_;

  typedef std::remove_extent<mkldnn_dims_t>::type mkldnn_dim_t;

#define INVALID_DIM_SIZE -1

 public:
  MklDnnShape() {
    for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
         ++i) {
      data_.sizes_[i] = -1;
    }
    for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) {
      data_.map_[i] = -1;
    }
  }

  ~MklDnnShape() {}
  TF_DISALLOW_COPY_AND_ASSIGN(MklDnnShape);  // Cannot copy

  /// Equality function for MklDnnShape objects
  /// @return true if both are equal; false otherwise.
  inline bool operator==(const MklDnnShape& input_shape) const {
    if (this->IsMklTensor() != input_shape.IsMklTensor()) {
      return false;
    }

    // If input tensors are in MKL layout, then we check for dimensions and
    // sizes.
    if (this->IsMklTensor()) {
      const mkldnn_memory_desc_t& cur_md = (this->GetMklLayout()).data;
      const mkldnn_memory_desc_t& input_shape_md =
          input_shape.GetMklLayout().data;
      return this->GetTfShape() == input_shape.GetTfShape() &&
             mkldnn_memory_desc_equal(&cur_md, &input_shape_md);
    }

    // Both inputs are not MKL tensors.
    return true;
  }

  /// Equality operator for MklDnnShape and TFShape.
  /// Returns: true if TF shapes for both are the same, false otherwise
  inline bool operator==(const TensorShape& input_shape) const {
    if (!this->IsMklTensor()) {
      return false;
    }

    return this->GetTfShape() == input_shape;
  }

  inline const bool IsMklTensor() const { return data_.is_mkl_tensor_; }
  inline void SetMklTensor(bool is_mkl_tensor) {
    data_.is_mkl_tensor_ = is_mkl_tensor;
  }

  inline void SetDimensions(const size_t dimension) {
    data_.dimension_ = dimension;
  }
  inline size_t GetDimension(char dimension) const {
    int index = GetMklDnnTensorDimIndex(dimension);
    CHECK(index >= 0 && index < this->GetDimension())
        << "Invalid index from the dimension: " << index << ", " << dimension;
    return this->DimSize(index);
  }

  inline size_t GetDimension3D(char dimension) const {
    int index = GetMklDnnTensor3DDimIndex(dimension);
    CHECK(index >= 0 && index < this->GetDimension())
        << "Invalid index from the dimension: " << index << ", " << dimension;
    return this->DimSize(index);
  }

  inline int32 GetMklDnnTensorDimIndex(char dimension) const {
    switch (dimension) {
      case 'N':
        return MklDnnDims::Dim_N;
      case 'C':
        return MklDnnDims::Dim_C;
      case 'H':
        return MklDnnDims::Dim_H;
      case 'W':
        return MklDnnDims::Dim_W;
      default:
        LOG(FATAL) << "Invalid dimension: " << dimension;
        return -1;  // Avoid compiler warning about missing return value
    }
  }

  inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
    switch (dimension) {
      case 'N':
        return MklDnnDims3D::Dim3d_N;
      case 'C':
        return MklDnnDims3D::Dim3d_C;
      case 'D':
        return MklDnnDims3D::Dim3d_D;
      case 'H':
        return MklDnnDims3D::Dim3d_H;
      case 'W':
        return MklDnnDims3D::Dim3d_W;
      default:
        LOG(FATAL) << "Invalid dimension: " << dimension;
        return -1;  // Avoid compiler warning about missing return value
    }
  }

  inline size_t GetDimension() const { return data_.dimension_; }
  inline const int* GetSizes() const {
    return reinterpret_cast<const int*>(&data_.sizes_[0]);
  }

  // Returns an mkldnn::memory::dims object that contains the sizes of this
  // MklDnnShape object.
  inline memory::dims GetSizesAsMklDnnDims() const {
    memory::dims retVal;
    if (data_.is_mkl_tensor_) {
      size_t dimensions = sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
      for (size_t i = 0; i < dimensions; i++) {
        if (data_.sizes_[i] != INVALID_DIM_SIZE)
          retVal.push_back(data_.sizes_[i]);
      }
    } else {
      CHECK_EQ(data_.is_mkl_tensor_, true);
    }
    return retVal;
  }

  inline int64 DimSize(int index) const {
    CHECK_LT(index, sizeof(data_.sizes_) / sizeof(data_.sizes_[0]));
    return data_.sizes_[index];
  }

  /// Return TensorShape that describes the Tensorflow shape of the tensor
  /// represented by this MklShape.
  inline TensorShape GetTfShape() const {
    CHECK_EQ(data_.is_mkl_tensor_, true);

    std::vector<int32> shape(data_.dimension_, -1);
    // As mentioned in the comment above, we now rely on TF's `data_format`
    // attribute to determine if TF shape is in blocked format or not.
    if (data_.tf_data_format_ != MklTensorFormat::FORMAT_BLOCKED) {
      for (size_t idx = 0; idx < data_.dimension_; ++idx) {
        shape[idx] = data_.sizes_[TfDimIdx(idx)];
      }
    } else {
      // If Tensorflow shape is in Blocked format, then we don't have dimension
      // map for it. So we just create Tensorflow shape from sizes in the
      // specified order.
      for (size_t idx = 0; idx < data_.dimension_; ++idx) {
        shape[idx] = data_.sizes_[idx];
      }
    }

    TensorShape ts;
    bool ret = TensorShapeUtils::MakeShape(shape, &ts).ok();
    CHECK_EQ(ret, true);
    return ts;
  }

  inline void SetElemType(memory::data_type dt) { data_.T_ = dt; }
  inline const memory::data_type GetElemType() { return data_.T_; }

  inline void SetMklLayout(memory::desc* md) {
    CHECK_NOTNULL(md);
    data_.mkl_md_ = md->data;
  }

  inline const memory::desc GetMklLayout() const {
    return memory::desc(data_.mkl_md_);
  }

  inline MklTensorFormat GetTfDataFormat() const {
    return data_.tf_data_format_;
  }

  /// We don't create primitive_descriptor for TensorFlow layout now.
  /// We use lazy evaluation and create it only when needed. Input format can
  /// also be Blocked format.
  inline void SetTfLayout(size_t dims, const memory::dims& sizes,
                          MklTensorFormat format) {
    DCHECK_EQ(dims, sizes.size())
        << "SetTfLayout: Number of dimensions does not "
           "match with dimension array";
    data_.dimension_ = dims;
    for (size_t ii = 0; ii < dims; ++ii) {
      data_.sizes_[ii] = sizes[ii];
    }
    data_.tf_data_format_ = format;
    if (format != MklTensorFormat::FORMAT_BLOCKED) {
      if (dims == 2) {
        data_.map_[0] = MklDnnDims::Dim_N;
        data_.map_[1] = MklDnnDims::Dim_C;
      } else {
        SetTfDimOrder(dims, format);
      }
    }
  }

  inline const memory::desc GetTfLayout() const {
    memory::dims dims;
    for (size_t ii = 0; ii < data_.dimension_; ++ii) {
      dims.push_back(data_.sizes_[ii]);
    }

    // Create Blocked memory desc if input TF format was set like that.
    if (data_.tf_data_format_ == MklTensorFormat::FORMAT_BLOCKED) {
      auto strides = CalculateTFStrides(dims);
      mkldnn_memory_desc_t blocked_md;
      TF_CHECK_OK(
          CreateBlockedMemDescHelper(dims, strides, data_.T_, &blocked_md));
      return memory::desc(blocked_md);
    } else {
      auto format_tag =
          MklTensorFormatToMklDnnDataFormat(data_.tf_data_format_);
      return memory::desc(dims, data_.T_, format_tag);
    }
  }

  inline const memory::desc GetCurLayout() const {
    return IsMklTensor() ? GetMklLayout() : GetTfLayout();
  }

  // We don't need a default dimension-order case here: when an operator that
  // has no data_format attribute receives all inputs in Tensorflow format, it
  // will produce its output in Tensorflow format as well.
  inline void SetTfDimOrder(const size_t dimension, const mkldnn_dims_t map) {
    CHECK(dimension == data_.dimension_);
    for (size_t ii = 0; ii < dimension; ii++) {
      data_.map_[ii] = map[ii];
    }
  }

  inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
    if (dimension == 5) {
      CHECK(dimension == data_.dimension_);
      data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
          MklDnnDims3D::Dim3d_D;
      data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
          MklDnnDims3D::Dim3d_H;
      data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
          MklDnnDims3D::Dim3d_W;
      data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
          MklDnnDims3D::Dim3d_C;
      data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
          MklDnnDims3D::Dim3d_N;
    } else {
      CHECK_EQ(dimension, 4);
      CHECK(dimension == data_.dimension_);
      data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
      data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
      data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
      data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
    }
  }

  inline void SetTfDimOrder(const size_t dimension, MklTensorFormat format) {
    TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
    SetTfDimOrder(dimension, data_format);
  }

  inline const mkldnn_dim_t* GetTfToMklDimMap() const { return &data_.map_[0]; }
  inline size_t TfDimIdx(int index) const { return data_.map_[index]; }
  inline int64 TfDimSize(int index) const {
    return data_.sizes_[TfDimIdx(index)];
  }

  /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
  /// corresponds to MKL's Channel dimension.
  inline bool IsMklChannelDim(int d) const {
    return TfDimIdx(d) == MklDnnDims::Dim_C;
  }

  /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
  /// corresponds to MKL's Batch dimension.
  inline bool IsMklBatchDim(int d) const {
    return TfDimIdx(d) == MklDnnDims::Dim_N;
  }

  /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
  /// corresponds to MKL's Width dimension.
  inline bool IsMklWidthDim(int d) const {
    return TfDimIdx(d) == MklDnnDims::Dim_W;
  }
  /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
  /// corresponds to MKL's Height dimension.
  inline bool IsMklHeightDim(int d) const {
    return TfDimIdx(d) == MklDnnDims::Dim_H;
  }

  /// Check if the TF-MKL dimension ordering map specifies that the input
  /// tensor is in NCHW format.
  inline bool IsTensorInNCHWFormat() const {
    TensorFormat data_format = FORMAT_NCHW;
    return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
            IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
            IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
            IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
  }

  /// Check if the TF-MKL dimension ordering map specifies that the input
  /// tensor is in NHWC format.
  inline bool IsTensorInNHWCFormat() const {
    TensorFormat data_format = FORMAT_NHWC;
    return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
            IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
            IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
            IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
  }
  /// The following methods are used for serializing and de-serializing the
  /// contents of an MklDnnShape object.
  /// The data is serialized in this order:
  /// is_mkl_tensor_ : dimension_ : sizes_ : map_ : tf_data_format_ : T_ :
  /// mkl_md_;

  /// Size of buffer to hold the serialized object; the size is computed
  /// using the order mentioned above.
  inline size_t GetSerializeBufferSize() const { return sizeof(MklShapeData); }

  void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const {
    CHECK(buf_size >= GetSerializeBufferSize())
        << "Buffer size is too small to SerializeMklDnnShape";
    *reinterpret_cast<MklShapeData*>(buf) = data_;
  }

  void DeSerializeMklDnnShape(const unsigned char* buf, size_t buf_size) {
    // Make sure buffer holds at least is_mkl_tensor_.
    CHECK(buf_size >= sizeof(data_.is_mkl_tensor_))
        << "Buffer size is too small in DeSerializeMklDnnShape";

    const bool is_mkl_tensor = *reinterpret_cast<const bool*>(buf);
    if (is_mkl_tensor) {  // If it is an MKL Tensor then read the rest
      CHECK(buf_size >= GetSerializeBufferSize())
          << "Buffer size is too small in DeSerializeMklDnnShape";
      data_ = *reinterpret_cast<const MklShapeData*>(buf);
    }
  }
};
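
// Illustrative round-trip sketch for the serialization methods above (an
// assumption, not from the original header):
//
//   MklDnnShape shape;
//   shape.SetMklTensor(false);
//   Tensor meta(DT_UINT8,
//               {static_cast<int64>(shape.GetSerializeBufferSize())});
//   shape.SerializeMklDnnShape(meta.flat<uint8>().data(),
//                              meta.flat<uint8>().size() * sizeof(uint8));
//
//   MklDnnShape restored;
//   restored.DeSerializeMklDnnShape(meta.flat<uint8>().data(),
//                                   meta.flat<uint8>().size() * sizeof(uint8));
//   DCHECK(!restored.IsMklTensor());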

// List of MklDnnShape objects. Used in Concat/Split layers.
typedef std::vector<MklDnnShape> MklDnnShapeList;

template <typename T>
class MklDnnData;

// TODO: merge this with execute_primitives().
inline void ExecutePrimitive(const std::vector<primitive>& net,
                             const std::vector<MemoryArgsMap>* net_args,
                             const engine& cpu_engine,
                             OpKernelContext* context = nullptr) {
  DCHECK(net_args);
  DCHECK_EQ(net.size(), net_args->size());
  std::unique_ptr<stream> cpu_stream;
  MklDnnThreadPool eigen_tp;
  if (context != nullptr) {
    eigen_tp = MklDnnThreadPool(context);
    cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));
  } else {
    cpu_stream.reset(CreateStream(nullptr, cpu_engine));
  }
  for (size_t i = 0; i < net.size(); ++i) {
    net.at(i).execute(*cpu_stream, net_args->at(i));
  }
  cpu_stream->wait();
}
template <typename T>
inline Status ConvertMklToTF(OpKernelContext* context,
                             const Tensor& input_mkl_tensor,
                             const MklDnnShape& input_mkl_shape,
                             Tensor* output_tf_tensor) {
  try {
    if (!input_mkl_shape.IsMklTensor()) {
      // Return input as is since it is already a TF tensor
      *output_tf_tensor = input_mkl_tensor;
      return Status::OK();
    }

    // Allocate output tensor.
    TensorShape output_tf_shape = input_mkl_shape.GetTfShape();
    TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<T>::v(), output_tf_shape,
                                       output_tf_tensor));

    engine cpu_engine(engine::kind::cpu, 0);
    MklDnnData<T> input(&cpu_engine);

    // Get MKL layout of input tensor.
    auto input_mkl_md = input_mkl_shape.GetMklLayout();
    auto output_tf_md = input_mkl_shape.GetTfLayout();
    input.SetUsrMem(input_mkl_md, &input_mkl_tensor);

    if (input.IsReorderNeeded(output_tf_md)) {
      std::vector<primitive> net;
      std::vector<MemoryArgsMap> net_args;
      bool status = input.CheckReorderToOpMem(output_tf_md, output_tf_tensor,
                                              net, net_args, cpu_engine);
      if (!status) {
        return Status(error::Code::INTERNAL,
                      "ConvertMklToTF(): Failed to create reorder for input");
      }
      ExecutePrimitive(net, &net_args, cpu_engine, context);
    } else {
      // If not, just forward input tensor to output tensor.
      bool status =
          output_tf_tensor->CopyFrom(input_mkl_tensor, output_tf_shape);
      if (!status) {
        return Status(
            error::Code::INTERNAL,
            "ConvertMklToTF(): Failed to forward input tensor to output");
      }
    }
    return Status::OK();
  } catch (mkldnn::error& e) {
    string error_msg = "Status: " + std::to_string(e.status) +
                       ", message: " + string(e.message) + ", in file " +
                       string(__FILE__) + ":" + std::to_string(__LINE__);
    LOG(FATAL) << "Operation received an exception: " << error_msg;
  }
}
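
// Illustrative usage sketch inside a kernel's Compute() (an assumption, not
// from the original header); `kInputIdx` is a hypothetical input index:
//
//   const Tensor& input = MklGetInput(context, kInputIdx);
//   MklDnnShape input_shape;
//   GetMklShape(context, kInputIdx, &input_shape);
//   Tensor tf_tensor;
//   OP_REQUIRES_OK(context, ConvertMklToTF<float>(context, input, input_shape,
//                                                 &tf_tensor));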

// Get the MKL shape from the second (metadata) tensor of input `n`.
inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape,
                        bool eager_mode) {
  if (!eager_mode) {
    mklshape->DeSerializeMklDnnShape(
        ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
            .flat<uint8>()
            .data(),
        ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
                .flat<uint8>()
                .size() *
            sizeof(uint8));
  } else {
    mklshape->SetMklTensor(false);
  }
}

inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
  GetMklShape(ctext, n, mklshape, false);
}

// Gets the actual input
inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
  return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
}

inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
                            OpInputList* input_tensors) {
  CHECK_NOTNULL(input_tensors);
  TF_CHECK_OK(ctext->input_list(name, input_tensors));
}

inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
                            MklDnnShapeList* mkl_shapes,
                            bool native_format = false) {
  if (!native_format) {
    OpInputList input_mkl_tensors;
    GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);

    for (int i = 0; i < input_mkl_tensors.size(); i++) {
      (*mkl_shapes)[i].DeSerializeMklDnnShape(
          input_mkl_tensors[i].flat<uint8>().data(),
          input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
    }
  } else {
    for (int i = 0; i < mkl_shapes->size(); ++i) {
      (*mkl_shapes)[i].SetMklTensor(false);
    }
  }
}

/// Get shape of input tensor pointed to by 'input_idx' in TensorShape format.
/// If the input tensor is in MKL layout, then obtains TensorShape from
/// MklShape.
inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx,
                              bool eager_mode = false) {
  // Sanity check.
  CHECK_NOTNULL(context);
  CHECK_LT(input_idx, context->num_inputs());

  MklDnnShape input_mkl_shape;
  GetMklShape(context, input_idx, &input_mkl_shape, eager_mode);
  if (input_mkl_shape.IsMklTensor() && !eager_mode) {
    return input_mkl_shape.GetTfShape();
  } else {
    const Tensor& t = MklGetInput(context, input_idx);
    return t.shape();
  }
}

// Allocate the second output tensor that will contain
// the serialized MKL shape.
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
                                      const MklDnnShape& mkl_shape) {
  Tensor* second_tensor = nullptr;
  TensorShape second_shape;
  second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
  OP_REQUIRES_OK(ctext, ctext->allocate_output(
                            GetTensorMetaDataIndex(n, ctext->num_outputs()),
                            second_shape, &second_tensor));
  mkl_shape.SerializeMklDnnShape(
      second_tensor->flat<uint8>().data(),
      second_tensor->flat<uint8>().size() * sizeof(uint8));
}

// Allocate the output tensor and create a second output tensor that will
// contain the serialized MKL shape.
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
                                      Tensor** output,
                                      const TensorShape& tf_shape,
                                      const MklDnnShape& mkl_shape,
                                      bool eager_mode = false) {
  OP_REQUIRES_OK(
      ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
                                    tf_shape, output));
  if (!eager_mode) {
    Tensor* second_tensor = nullptr;
    TensorShape second_shape;
    second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
    OP_REQUIRES_OK(ctext, ctext->allocate_output(
                              GetTensorMetaDataIndex(n, ctext->num_outputs()),
                              second_shape, &second_tensor));
    mkl_shape.SerializeMklDnnShape(
        second_tensor->flat<uint8>().data(),
        second_tensor->flat<uint8>().size() * sizeof(uint8));
  }
}
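
// Illustrative usage sketch (an assumption, not from the original header): a
// kernel that produces its output in plain TF layout allocates the data
// tensor and the serialized metadata tensor in one call; `kOutputIdx` and
// `output_tf_shape` are hypothetical:
//
//   Tensor* output = nullptr;
//   MklDnnShape output_mkl_shape;
//   output_mkl_shape.SetMklTensor(false);
//   AllocateOutputSetMklShape(context, kOutputIdx, &output, output_tf_shape,
//                             output_mkl_shape);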

// Allocates a temp tensor and returns the data buffer for temporary storage.
template <typename T>
inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
                           const memory::desc& pd, void** buf_out) {
  TensorShape tf_shape;

  tf_shape.AddDim(pd.get_size() / sizeof(T) + 1);
  OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
                                                 tf_shape, tensor_out));
  *buf_out = static_cast<void*>(tensor_out->flat<T>().data());
}

template <typename T>
inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
                           TensorShape tf_shape) {
  OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
                                                 tf_shape, tensor_out));
}

inline void GetStridesFromSizes(MklTensorFormat data_format, size_t* strides,
                                const size_t* sizes) {
  DCHECK_NE(data_format, MklTensorFormat::FORMAT_INVALID);
  // MKL requires strides in NCHW
  if (data_format == MklTensorFormat::FORMAT_NHWC) {
    strides[0] = sizes[2];
    strides[1] = sizes[0] * sizes[2];
    strides[2] = 1;
    strides[3] = sizes[0] * sizes[1] * sizes[2];
  } else {
    strides[0] = 1;
    strides[1] = sizes[0];
    strides[2] = sizes[0] * sizes[1];
    strides[3] = sizes[0] * sizes[1] * sizes[2];
  }
}

inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
                                 int idx_out) {
  int num_inputs = context->num_inputs();
  int num_outputs = context->num_outputs();
  int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
  int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
  int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
  int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);

  const Tensor& data = context->input(idx_data_in);
  const Tensor& meta = context->input(idx_meta_in);
  Tensor output(data.dtype());
  Tensor meta_output(meta.dtype());

  // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
  CHECK(output.CopyFrom(data, data.shape()));
  CHECK(meta_output.CopyFrom(meta, meta.shape()));
  context->set_output(idx_data_out, output);
  context->set_output(idx_meta_out, meta_output);
}

inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
                                         int idx_out,
                                         const TensorShape& shape) {
  int num_inputs = context->num_inputs();
  int num_outputs = context->num_outputs();
  int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
  int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);

  const Tensor& data = context->input(idx_data_in);
  MklDnnShape mkl_shape_output;
  mkl_shape_output.SetMklTensor(false);
  AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
  Tensor output(data.dtype());
  // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
  CHECK(output.CopyFrom(data, shape));
  context->set_output(idx_data_out, output);
}

inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
                                   int idx_out) {
  int num_inputs = context->num_inputs();
  int num_outputs = context->num_outputs();
  int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
  int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);

  MklDnnShape dnn_shape_output;
  dnn_shape_output.SetMklTensor(false);
  AllocateOutputSetMklShape(context, idx_out, dnn_shape_output);
  if (IsRefType(context->input_dtype(idx_data_in))) {
    context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
  } else {
    context->set_output(idx_data_out, context->input(idx_data_in));
  }
}

inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in,
                                    int idx_out) {
  int num_inputs = context->num_inputs();
  int num_outputs = context->num_outputs();
  int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
  int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
  int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
  int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);

  if (IsRefType(context->input_dtype(idx_data_in))) {
    context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
    context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
  } else {
    context->set_output(idx_data_out, context->input(idx_data_in));
    context->set_output(idx_meta_out, context->input(idx_meta_in));
  }
}
// Set a dummy MKLDNN shape (called when the output is in TF format)
inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
                                      uint32 idx_data_out) {
  MklDnnShape mkl_shape_output;
  mkl_shape_output.SetMklTensor(false);
  AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
}

// If the input tensor has ref count as 1, it is forwarded to the desired
// output port and the function returns true. In that case, it also allocates
// the serialized MklDnnShape object. Otherwise, the function returns false.
inline bool ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
                                                int idx_in, int idx_out,
                                                Tensor** output,
                                                const MklDnnShape& mkl_shape,
                                                bool always_forward = true) {
  int num_inputs = context->num_inputs();
  int num_outputs = context->num_outputs();
  int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
  int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
  bool is_forwarded = false;
  const Tensor& input_tensor = context->input(idx_data_in);
  const auto output_shape = input_tensor.shape();
  if (always_forward) {
    if (IsRefType(context->input_dtype(idx_data_in))) {
      context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
    } else {
      context->set_output(idx_data_out, input_tensor);
    }
  } else {
    is_forwarded = context->forward_input_to_output_with_shape(
        idx_data_in, idx_data_out, output_shape, output);
  }
  if (is_forwarded || always_forward) {
    AllocateOutputSetMklShape(context, idx_out, mkl_shape);
    return true;
  }
  return false;
}

// Forward the MKL shape ONLY (used in elementwise and other ops where
// we call the eigen implementation and MKL shape is not used)
inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
                                      uint32 idx_data_in,
                                      uint32_t idx_data_out) {
  uint32 idx_meta_in =
      GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
  uint32 idx_meta_out =
      GetTensorMetaDataIndex(idx_data_out, context->num_outputs());

  if (IsRefType(context->input_dtype(idx_data_in))) {
    context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
  } else {
    context->set_output(idx_meta_out, context->input(idx_meta_in));
  }
}

// -------------------------------------------------------------------
// Common utility functions used by MKL unit tests

inline Tensor GetMklMetaTensor() {
  MklDnnShape non_mkl_shape;
  non_mkl_shape.SetMklTensor(false);

  auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize());
  Tensor tensor(DT_UINT8, {size});

  non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
                                     size * sizeof(uint8));
  return tensor;
}
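
// Illustrative usage sketch (an assumption, not from the original header): in
// an OpsTestBase-style unit test, the dummy meta tensor accompanies every
// plain TF data input of an MKL kernel:
//
//   const Tensor dummy_meta = GetMklMetaTensor();
//   AddInputFromArray<uint8>(dummy_meta.shape(), dummy_meta.flat<uint8>());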

// -------------------------------------------------------------------

/// Return MKL-DNN data type (memory::data_type) for input type T
///
/// @input None
/// @return memory::data_type corresponding to type T
template <typename T>
static memory::data_type MklDnnType();

/// Instantiation for float type. Add similar instantiations for other
/// types if needed.
template <>
memory::data_type MklDnnType<float>() {
  return memory::data_type::f32;
}

template <>
memory::data_type MklDnnType<quint8>() {
  return memory::data_type::u8;
}

template <>
memory::data_type MklDnnType<uint8>() {
  return memory::data_type::u8;
}

template <>
memory::data_type MklDnnType<qint8>() {
  return memory::data_type::s8;
}

template <>
memory::data_type MklDnnType<qint32>() {
  return memory::data_type::s32;
}
template <>
memory::data_type MklDnnType<bfloat16>() {
  return memory::data_type::bf16;
}

// Map MklTensorFormat to MKL-DNN format tag
//
// @input: MklTensorFormat i.e. TensorFlow data format
// @return: MKL-DNN's memory format tag corresponding to MklTensorFormat.
//          Fails with an error if invalid data format.
inline memory::format_tag MklTensorFormatToMklDnnDataFormat(
    MklTensorFormat format) {
  if (format == MklTensorFormat::FORMAT_NHWC) return memory::format_tag::nhwc;
  if (format == MklTensorFormat::FORMAT_NCHW) return memory::format_tag::nchw;
  if (format == MklTensorFormat::FORMAT_NDHWC) return memory::format_tag::ndhwc;
  if (format == MklTensorFormat::FORMAT_NCDHW) return memory::format_tag::ncdhw;
  if (format == MklTensorFormat::FORMAT_X) return memory::format_tag::x;
  if (format == MklTensorFormat::FORMAT_NC) return memory::format_tag::nc;
  if (format == MklTensorFormat::FORMAT_TNC) return memory::format_tag::tnc;
  return memory::format_tag::undef;
}

/// Map TensorFlow data format into MKL-DNN 3D data format
/// @input: TensorFlow data format
/// @return: MKL-DNN 3D data format corresponding to TensorFlow data format;
///          Fails with an error if invalid data format.
inline MklTensorFormat TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
  if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NDHWC;
  if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCDHW;
  TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
  return MklTensorFormat::FORMAT_INVALID;
}

/// Map TensorFlow data format into MKL-DNN data format
///
/// @input: TensorFlow data format
/// @return: MKL-DNN data format corresponding to TensorFlow data format;
///          Fails with an error if invalid data format.
inline MklTensorFormat TFDataFormatToMklDnnDataFormat(TensorFormat format) {
  if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NHWC;
  if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCHW;
  TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
  return MklTensorFormat::FORMAT_INVALID;
}

/// Map MKL-DNN data format into TensorFlow data format
///
/// @input: MKL-DNN data format
/// @return: Tensorflow data format corresponding to MKL-DNN data format;
///          Fails with an error if invalid data format.
inline TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format) {
  if (format == MklTensorFormat::FORMAT_NHWC ||
      format == MklTensorFormat::FORMAT_NDHWC)
    return FORMAT_NHWC;
  if (format == MklTensorFormat::FORMAT_NCHW ||
      format == MklTensorFormat::FORMAT_NCDHW)
    return FORMAT_NCHW;
  TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));

  // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure
  // that we don't come here.
  return FORMAT_NHWC;
}

/// Map TensorShape object into memory::dims required by MKL-DNN
///
/// This function will simply map input TensorShape into MKL-DNN dims
/// naively. So it will preserve the order of dimensions. E.g., if
/// input tensor is in NHWC format, then dims will be in NHWC format also.
///
/// @input TensorShape object in shape
/// @return memory::dims corresponding to TensorShape
inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
  memory::dims dims(shape.dims());
  for (int d = 0; d < shape.dims(); ++d) {
    dims[d] = shape.dim_size(d);
  }
  return dims;
}

/// Map TensorShape object into memory::dims in NCHW format required by MKL-DNN
///
/// This function is more specific than the one above. It maps the input
/// TensorShape into MKL-DNN dims in NCHW format, so it may not preserve the
/// order of dimensions. E.g., if the input tensor is in NHWC format, then the
/// dims will be in NCHW format, and not in NHWC format.
///
/// @input TensorShape object in shape
/// @return memory::dims in MKL-DNN required NCHW format
inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
                                              TensorFormat format) {
  // Check validity of format.
  DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
            MklTensorFormat::FORMAT_INVALID);

  int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
  int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
  int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
  int w = shape.dim_size(GetTensorDimIndex(format, 'W'));

  // MKL-DNN requires dimensions in NCHW format.
  return memory::dims({n, c, h, w});
}
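
// Illustrative example (not from the original header): for an NHWC input of
// shape {8, 224, 224, 3},
//
//   TFShapeToMklDnnDimsInNCHW(TensorShape({8, 224, 224, 3}), FORMAT_NHWC)
//
// returns memory::dims({8, 3, 224, 224}).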

inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
                                               TensorFormat format) {
  // Validate format.
  DCHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
            MklTensorFormat::FORMAT_INVALID);

  int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
  int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
  int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
  int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
  int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));

  // MKL-DNN requires dimensions in NCDHW format.
  return memory::dims({n, c, d, h, w});
}

/// Overloaded version of function TFShapeToMklDnnDimsInNCHW above.
/// Input parameters are self-explanatory.
inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
                                     TensorFormat format) {
  // Validate format.
  DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
            MklTensorFormat::FORMAT_INVALID);

  int n = in_dims[GetTensorDimIndex(format, 'N')];
  int c = in_dims[GetTensorDimIndex(format, 'C')];
  int h = in_dims[GetTensorDimIndex(format, 'H')];
  int w = in_dims[GetTensorDimIndex(format, 'W')];

  // MKL-DNN requires dimensions in NCHW format.
  return memory::dims({n, c, h, w});
}

/// Overloaded version of function TFShapeToMklDnnDimsInNCDHW above.
/// Input parameters are self-explanatory.
inline memory::dims MklDnnDimsInNCDHW(const memory::dims& in_dims,
                                      TensorFormat format) {
  // Validate format.
  DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
            MklTensorFormat::FORMAT_INVALID);

  int n = in_dims[GetTensorDimIndex<3>(format, 'N')];
  int c = in_dims[GetTensorDimIndex<3>(format, 'C')];
  int d = in_dims[GetTensorDimIndex<3>(format, '0')];
  int h = in_dims[GetTensorDimIndex<3>(format, '1')];
  int w = in_dims[GetTensorDimIndex<3>(format, '2')];

  // MKL DNN requires dimensions in NCDHW format.
  return memory::dims({n, c, d, h, w});
}

/// Map MklDnn memory::dims object into TensorShape object.
///
/// This function will simply map the input shape in MKL-DNN memory::dims
/// format into Tensorflow's TensorShape object by preserving dimension order.
///
/// @input MKL-DNN memory::dims object
/// @output TensorShape corresponding to memory::dims
inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) {
  std::vector<int32> shape(dims.size(), -1);
  for (int d = 0; d < dims.size(); d++) {
    shape[d] = dims[d];
  }

  TensorShape ret;
  CHECK_EQ(TensorShapeUtils::MakeShape(shape, &ret).ok(), true);
  return ret;
}

/// Function to calculate strides given tensor shape in Tensorflow order
/// E.g., if dims_tf_order is {1, 2, 3, 4}, then as per Tensorflow convention,
/// the dimension with size 1 is the outermost dimension, while the dimension
/// with size 4 is the innermost dimension. So strides for this tensor would
/// be {4 * 3 * 2, 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}.
///
/// @input Tensorflow shape in memory::dims type
/// @return memory::dims containing strides for the tensor.
inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) {
  CHECK_GT(dims_tf_order.size(), 0);
  memory::dims strides(dims_tf_order.size());
  int last_dim_idx = dims_tf_order.size() - 1;
  strides[last_dim_idx] = 1;
  for (int d = last_dim_idx - 1; d >= 0; d--) {
    strides[d] = strides[d + 1] * dims_tf_order[d + 1];
  }
  return strides;
}

/// Helper function to create memory descriptor in Blocked format
///
/// @input: Tensor dimensions
/// @input: strides corresponding to dimensions. One can use utility
///         function such as CalculateTFStrides to compute strides
///         for given dimensions.
/// @output: mkldnn_memory_desc_t object corresponding to blocked memory
///          format for given dimensions and strides.
/// @return: Status indicating whether the blocked memory descriptor
///          was successfully created.
inline Status CreateBlockedMemDescHelper(const memory::dims& dim,
                                         const memory::dims& strides,
                                         memory::data_type dtype,
                                         mkldnn_memory_desc_t* blocked_md) {
  DCHECK_EQ(dim.size(), strides.size());
  const int kNumDims = dim.size();
  mkldnn_dim_t* input_dims = new mkldnn_dim_t[kNumDims];
  mkldnn_dim_t* input_strides = new mkldnn_dim_t[kNumDims];
  for (int i = 0; i < kNumDims; ++i) {
    input_dims[i] = dim[i];
    input_strides[i] = strides[i];
  }
  try {
    mkldnn_memory_desc_init_by_strides(blocked_md, kNumDims, input_dims,
                                       memory::convert_to_c(dtype),
                                       input_strides);
    delete[] input_dims;
    delete[] input_strides;
  } catch (mkldnn::error& e) {
    delete[] input_dims;
    delete[] input_strides;
    return Status(error::Code::INTERNAL,
                  tensorflow::strings::StrCat(
                      "Failed to create blocked memory descriptor.",
                      "Status: ", e.status, ", message: ", e.message));
  }
  return Status::OK();
}
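
// Illustrative usage sketch (an assumption, not from the original header):
// building a blocked memory descriptor for a plain row-major float tensor of
// shape {2, 3, 4}:
//
//   memory::dims dims = {2, 3, 4};
//   memory::dims strides = CalculateTFStrides(dims);  // {12, 4, 1}
//   mkldnn_memory_desc_t blocked_md;
//   TF_CHECK_OK(CreateBlockedMemDescHelper(dims, strides,
//                                          memory::data_type::f32,
//                                          &blocked_md));
//   memory::desc md(blocked_md);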

inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
                                    const memory& src_mem,
                                    const memory& dst_mem, const engine& engine,
                                    OpKernelContext* ctx = nullptr) {
  std::vector<primitive> net;
  net.push_back(mkldnn::reorder(reorder_desc));
  std::vector<MemoryArgsMap> net_args;
  net_args.push_back({{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
  ExecutePrimitive(net, &net_args, engine, ctx);
}

class MklReorderPrimitive;

template <typename T>
inline MklReorderPrimitive* FindOrCreateReorder(const memory* from,
                                                const memory* to);

// Class to represent all the resources corresponding to a tensor in TensorFlow
// that are required to execute an operation (such as Convolution).
template <typename T>
class MklDnnData {
 private:
  /// MKL-DNN memory primitive for input user memory
  memory* user_memory_;

  /// MKL-DNN memory primitive in case input or output reorder is needed.
  memory* reorder_memory_;

  /// Operation's memory descriptor
  memory::desc* op_md_;
  // Flag to indicate whether the data is 3D or not.
  bool bIs3D;
  /// Operation's temp buffer
  void* allocated_buffer_;
  /// CPU engine on which operation will be executed
  const engine* cpu_engine_;

 public:
  explicit MklDnnData(const engine* e)
      : user_memory_(nullptr),
        reorder_memory_(nullptr),
        op_md_(nullptr),
        bIs3D(false),
        allocated_buffer_(nullptr),
        cpu_engine_(e) {}

  ~MklDnnData() {
    if (allocated_buffer_ != nullptr) {
      cpu_allocator()->DeallocateRaw(allocated_buffer_);
    }
    cpu_engine_ = nullptr;  // We don't own this.
    delete (user_memory_);
    delete (reorder_memory_);
    delete (op_md_);
  }

  inline void* GetTensorBuffer(const Tensor* tensor) const {
    CHECK_NOTNULL(tensor);
    return const_cast<void*>(
        static_cast<const void*>(tensor->flat<T>().data()));
  }

  void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
  bool GetIs3D() { return bIs3D; }

  /// Set user memory primitive using specified dimensions, memory format tag
  /// and data_buffer. The element data type is derived automatically from the
  /// template parameter T of this class.
  ///
  /// In a nutshell, this function allows the user to describe the input tensor
  /// to an operation. E.g., the filter of Conv2D is of shape {1, 2, 3, 4}, and
  /// memory format tag HWIO, and the buffer that contains actual values is
  /// pointed to by data_buffer.
  inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm,
                        void* data_buffer = nullptr) {
    auto md = memory::desc(dim, MklDnnType<T>(), fm);
    SetUsrMem(md, data_buffer);
  }
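
  // Illustrative usage sketch (an assumption, not from the original header):
  // describing an NHWC float input held in `input_tensor`; `cpu_engine` is a
  // hypothetical engine created by the caller. Note that MKL-DNN dims are
  // given in NCHW order, while the format tag records the physical layout:
  //
  //   MklDnnData<float> src(&cpu_engine);
  //   memory::dims src_dims =
  //       TFShapeToMklDnnDimsInNCHW(input_tensor.shape(), FORMAT_NHWC);
  //   src.SetUsrMem(src_dims, memory::format_tag::nhwc, &input_tensor);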

  inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm,
                        const Tensor* tensor) {
    DCHECK(tensor);
    SetUsrMem(dim, fm, GetTensorBuffer(tensor));
  }

  /// Helper function to create memory descriptor in Blocked format
  ///
  /// @input: Tensor dimensions
  /// @input: strides corresponding to dimensions. One can use utility
  ///         function such as CalculateTFStrides to compute strides
  ///         for given dimensions.
  /// @return: memory::desc object corresponding to blocked memory format
  ///          for given dimensions and strides.
  static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim,
                                                  const memory::dims& strides) {
    mkldnn_memory_desc_t blocked_md;
    TF_CHECK_OK(
        CreateBlockedMemDescHelper(dim, strides, MklDnnType<T>(), &blocked_md));
    return memory::desc(blocked_md);
  }

  /// A version of SetUsrMem that allows the user to create memory in blocked
  /// format. So in addition to accepting dimensions, it also accepts strides.
  /// This allows the user to create memory for a tensor in a format that is
  /// not supported natively by MKLDNN. E.g., MKLDNN has no native format for
  /// 6-dimensional tensors, but by using blocked format a user can create
  /// memory for a 6D tensor (see the illustrative sketch after this function).
  inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
                        void* data_buffer = nullptr) {
    CHECK_EQ(dim.size(), strides.size());
    auto blocked_md = MklDnnData<T>::CreateBlockedMemDesc(dim, strides);
    SetUsrMem(blocked_md, data_buffer);
  }
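
  // Illustrative sketch of the blocked-format path (an assumption, not from
  // the original header): describing a 6D tensor, which has no native MKL-DNN
  // format tag, via explicit strides; `cpu_engine` and `some_6d_tensor` are
  // hypothetical:
  //
  //   memory::dims dims_6d = {2, 3, 4, 5, 6, 7};
  //   memory::dims strides_6d = CalculateTFStrides(dims_6d);
  //   MklDnnData<float> data(&cpu_engine);
  //   data.SetUsrMem(dims_6d, strides_6d, &some_6d_tensor);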
1345
SetUsrMem(const memory::dims & dim,const memory::dims & strides,const Tensor * tensor)1346 inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
1347 const Tensor* tensor) {
1348 CHECK_NOTNULL(tensor);
1349 SetUsrMem(dim, strides, GetTensorBuffer(tensor));
1350 }
1351
1352 /// A version of SetUsrMem with memory descriptor and tensor
SetUsrMem(const memory::desc & md,const Tensor * tensor)1353 inline void SetUsrMem(const memory::desc& md, const Tensor* tensor) {
1354 CHECK_NOTNULL(tensor);
1355 SetUsrMem(md, GetTensorBuffer(tensor));
1356 }
1357
1358 /// A version of function to set user memory type that accepts memory
1359 /// descriptor directly, instead of accepting dimensions and format. This
1360 /// function is more generic than the one above, but the function above is
1361 /// sufficient in most cases.
1362 inline void SetUsrMem(const memory::desc& pd, void* data_buffer = nullptr) {
1363 DCHECK(cpu_engine_);
1364 if (user_memory_) delete user_memory_;
1365 // TODO(nhasabni): can we remove dynamic memory allocation?
1366 if (data_buffer) {
1367 user_memory_ = new memory(pd, *cpu_engine_, data_buffer);
1368 } else {
1369 user_memory_ = new memory(pd, *cpu_engine_);
1370 }
1371 }
1372
1373 /// Get function for user memory primitive.
GetUsrMem()1374 inline const memory* GetUsrMem() const { return user_memory_; }
1375
1376 /// Get function for descriptor of user memory.
GetUsrMemDesc()1377 inline memory::desc GetUsrMemDesc() const {
1378 DCHECK(user_memory_);
1379 return user_memory_->get_desc();
1380 }
1381
1382 /// Get function for data buffer of user memory primitive.
GetUsrMemDataHandle()1383 inline void* GetUsrMemDataHandle() const {
1384 CHECK_NOTNULL(user_memory_);
1385 return user_memory_->get_data_handle();
1386 }
1387
1388 /// Set function for data buffer of user memory primitive.
1389 inline void SetUsrMemDataHandle(void* data_buffer,
1390 std::shared_ptr<stream> t_stream = nullptr) {
1391 CHECK_NOTNULL(user_memory_);
1392 CHECK_NOTNULL(data_buffer);
1393 #ifndef ENABLE_ONEDNN_OPENMP
1394 user_memory_->set_data_handle(data_buffer, *t_stream);
1395 #else
1396 user_memory_->set_data_handle(data_buffer);
1397 #endif // !ENABLE_ONEDNN_OPENMP
1398 }
1399
1400 /// Set function for data buffer of user memory primitive.
1401 inline void SetUsrMemDataHandle(const Tensor* tensor,
1402 std::shared_ptr<stream> t_stream = nullptr) {
1403 SetUsrMemDataHandle(GetTensorBuffer(tensor), t_stream);
1404 }
1405
1406 /// allocate function for data buffer
AllocateBuffer(size_t size)1407 inline void AllocateBuffer(size_t size) {
1408 const int64 kMemoryAlignment = 64; // For AVX512 memory alignment.
1409 allocated_buffer_ = cpu_allocator()->AllocateRaw(kMemoryAlignment, size);
1410 }
1411
GetAllocatedBuffer()1412 inline void* GetAllocatedBuffer() { return allocated_buffer_; }
1413
1414 /// Get the memory primitive for input and output of an op. If inputs
1415 /// to an op require reorders, then this function returns memory primitive
1416 /// for reorder. Otherwise, it will return memory primitive for user memory.
1417 ///
1418 /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
1419 /// execute Conv2D, we need memory primitive for I and F. But if reorder is
1420 /// required for I and F (say I_r is reorder primitive for I; F_r is reorder
1421 /// primitive for F), then we need I_r and F_r to perform Conv2D.
1422 inline const memory& GetOpMem() const {
1423 return reorder_memory_ ? *reorder_memory_ : *user_memory_;
1424 }
1425
1426 /// Set memory descriptor of an operation in terms of dimensions and memory
1427 /// format. E.g., for Conv2D, the dimensions would be the same as the user
1428 /// dimensions, but the format tag would be memory::format_tag::any because
1429 /// we want MKL-DNN to choose the best layout/format for the given dimensions.
1430 inline void SetOpMemDesc(const memory::dims& dim, memory::format_tag fm) {
1431 // TODO(nhasabni): can we remove dynamic memory allocation?
1432 op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
1433 }
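/// Illustrative sketch (hypothetical dims and source object; assumes the
/// dims/format-tag SetUsrMem overload defined earlier in this class):
/// asking MKL-DNN to pick the layout for an op's input.
///
///   memory::dims src_dims = {1, 3, 224, 224};  // NCHW, made-up values
///   src.SetUsrMem(src_dims, memory::format_tag::nchw, &src_tensor);
///   src.SetOpMemDesc(src_dims, memory::format_tag::any);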
1434
1435 /// Get function for memory descriptor for an operation
1436 inline const memory::desc& GetOpMemDesc() const { return *op_md_; }
1437
1438 /// Predicate that checks if we need to reorder user's memory into memory
1439 /// pointed by op_md.
1440 ///
1441 /// @input: op_md - memory descriptor of the given input of an operation.
1442 /// @return: true in case reorder of input is needed; false, otherwise.
1443 inline bool IsReorderNeeded(const memory::desc& op_pd) const {
1444 DCHECK(user_memory_);
1445 return op_pd != user_memory_->get_desc();
1446 }
1447
1448 /// Function to create a reorder from the memory pointed to by `from` to the
1449 /// memory pointed to by `to`. Returns the created primitive.
1450 inline primitive CreateReorder(const memory* from, const memory* to) const {
1451 CHECK_NOTNULL(from);
1452 CHECK_NOTNULL(to);
1453 return reorder(*from, *to);
1454 }
1455
1456 /// Function to handle input reordering
1457 ///
1458 /// Check if we need to reorder this input of an operation.
1459 /// Return true and allocate reorder memory primitive if reorder is needed.
1460 /// Otherwise, return false and do not allocate reorder memory primitive.
1461 ///
1462 /// To check if reorder is needed, this function compares the memory
1463 /// primitive descriptor (memory descriptor for v1.x) of an operation
1464 /// (op_md) for the given input with the user-specified memory descriptor.
1465 ///
1466 /// @input: op_md - memory descriptor of the given input of an
1467 /// operation
1468 /// @input: net - net to which to add reorder primitive in case it is needed.
1469 /// @input: net_args - net to which user and reorder memories are added if
1470 /// needed. Each entry is a key-value pair of the form
1471 /// <argument-type, mkldnn::memory>.
1472 /// @return: true in case reorder of input is needed; false, otherwise.
1473 inline bool CheckReorderToOpMem(const memory::desc& op_md,
1474 std::vector<primitive>& net,
1475 std::vector<MemoryArgsMap>& net_args,
1476 const engine& engine) {
1477 DCHECK(user_memory_);
1478 DCHECK_EQ(net.size(), net_args.size());
1479 if (IsReorderNeeded(op_md)) {
1480 // TODO(nhasabni): can we remove dynamic memory allocation?
1481 reorder_memory_ = new memory(op_md, engine);
1482 net.push_back(CreateReorder(user_memory_, reorder_memory_));
1483 net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *user_memory_},
1484 {MKLDNN_ARG_TO, *reorder_memory_}});
1485 return true;
1486 }
1487 return false;
1488 }
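/// Illustrative sketch (made-up variable names; not part of the API): the
/// typical input-reorder pattern in a kernel using this overload.
///
///   std::vector<primitive> net;
///   std::vector<MemoryArgsMap> net_args;
///   if (src.CheckReorderToOpMem(op_md, net, net_args, cpu_engine)) {
///     // src.GetOpMem() now refers to the reordered memory.
///   }
///   execute_primitives(net, cpu_stream, net_args);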
1489
1490 inline bool CheckReorderToOpMem(const memory::desc& op_md,
1491 const engine& engine,
1492 OpKernelContext* context = nullptr) {
1493 DCHECK(user_memory_);
1494 if (IsReorderNeeded(op_md)) {
1495 // TODO(nhasabni): can we remove dynamic memory allocation?
1496 // Primitive reuse does not allow two identical reorder primitives in
1497 // one stream, so submit the reorder immediately.
1498 reorder_memory_ = new memory(op_md, engine);
1499 auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
1500 std::shared_ptr<stream> cpu_stream;
1501 MklDnnThreadPool eigen_tp;
1502 if (context != nullptr) {
1503 eigen_tp = MklDnnThreadPool(context);
1504 cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine()));
1505 } else {
1506 cpu_stream.reset(CreateStream(nullptr, prim->GetEngine()));
1507 }
1508 std::vector<primitive> net;
1509 net.push_back(*(prim->GetPrimitive()));
1510 std::vector<MemoryArgsMap> net_args;
1511 net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
1512 {MKLDNN_ARG_TO, *reorder_memory_}});
1513 execute_primitives(net, cpu_stream, net_args);
1514 return true;
1515 }
1516 return false;
1517 }
1518
1519 /// Overloaded version of above function that accepts memory buffer
1520 /// where output of reorder needs to be stored.
1521 ///
1522 /// @input: op_pd - memory primitive descriptor (memory descriptor for v1.x)
1523 /// of the given input of an operation
1524 /// @reorder_data_handle - memory buffer where output of reorder needs to be
1525 /// stored. Primitive does not check if buffer has
1526 /// enough size to write.
1527 /// @input: net - net to which to add reorder primitive in case it is needed.
1528 /// @input: net_args - net to which user and reorder memories are added if
1529 /// needed. Each entry is a key-value pair of the form
1530 /// <argument-type, mkldnn::memory>.
1531 /// @input: engine - MKL-DNN's abstraction of a computational device
1532 /// @return: true in case reorder of input is needed; false, otherwise.
1533 inline bool CheckReorderToOpMem(const memory::desc& op_md,
1534 void* reorder_data_handle,
1535 std::vector<primitive>& net,
1536 std::vector<MemoryArgsMap>& net_args,
1537 const engine& engine) {
1538 DCHECK(reorder_data_handle);
1539 DCHECK(user_memory_);
1540 if (IsReorderNeeded(op_md)) {
1541 // TODO(nhasabni): can we remove dynamic memory allocation?
1542 reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
1543 net.push_back(CreateReorder(user_memory_, reorder_memory_));
1544 net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *user_memory_},
1545 {MKLDNN_ARG_TO, *reorder_memory_}});
1546 return true;
1547 }
1548 return false;
1549 }
1550
1551 /// This is a faster path, using the reorder primitive cache, than
1552 /// CheckReorderToOpMem(..., std::vector<primitive>* net).
1553 /// The slower path will be removed in the future.
1554 /// TODO(bhavanis): Need to use reorder cache here for better performance.
1555 inline bool CheckReorderToOpMem(const memory::desc& op_md,
1556 void* reorder_data_handle,
1557 const engine& engine,
1558 OpKernelContext* context = nullptr) {
1559 DCHECK(reorder_data_handle);
1560 DCHECK(user_memory_);
1561 if (IsReorderNeeded(op_md)) {
1562 // TODO(nhasabni): can we remove dynamic memory allocation?
1563 // Primitive reuse does not allow two identical reorder primitives in
1564 // one stream, so submit the reorder immediately.
1565 reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
1566 auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
1567 std::shared_ptr<stream> cpu_stream;
1568 MklDnnThreadPool eigen_tp;
1569 if (context != nullptr) {
1570 eigen_tp = MklDnnThreadPool(context);
1571 cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine()));
1572 } else {
1573 cpu_stream.reset(CreateStream(nullptr, prim->GetEngine()));
1574 }
1575 std::vector<primitive> net;
1576 net.push_back(*(prim->GetPrimitive()));
1577 std::vector<MemoryArgsMap> net_args;
1578 net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
1579 {MKLDNN_ARG_TO, *reorder_memory_}});
1580 execute_primitives(net, cpu_stream, net_args);
1581 return true;
1582 }
1583 return false;
1584 }
1585
1586 /// Another overloaded version of CheckReorderToOpMem that accepts Tensor
1587 /// where output of reorder needs to be stored.
1588 ///
1589 /// @input: op_md - memory primitive descriptor (memory descriptor for v1.x)
1590 /// of the given input of an operation
1591 /// @reorder_tensor - Tensor whose buffer is to be used to store output of
1592 /// reorder. Primitive does not check if the buffer has
1593 /// enough size to write.
1594 /// @input: net - net to which to add reorder primitive in case it is needed.
1595 /// @input: net_args - net to which user and reorder memories are added if
1596 /// needed. Each entry is a key-value pair of the form
1597 /// <argument-type, mkldnn::memory>.
1598 /// @input: engine - MKL-DNN's abstraction of a computational device
1599 /// @return: true in case reorder of input is needed; false, otherwise.
1600 inline bool CheckReorderToOpMem(const memory::desc& op_md,
1601 Tensor* reorder_tensor,
1602 std::vector<primitive>& net,
1603 std::vector<MemoryArgsMap>& net_args,
1604 const engine& engine) {
1605 DCHECK(reorder_tensor);
1606 return CheckReorderToOpMem(op_md, GetTensorBuffer(reorder_tensor), net,
1607 net_args, engine);
1608 }
1609
1610 /// TODO: This is a faster path, using the reorder primitive cache, than
1611 /// CheckReorderToOpMem(op_md, reorder_tensor, net, net_args, engine).
1612 /// The slow path will be removed
1613 /// in the future.
1614 inline bool CheckReorderToOpMem(const memory::desc& op_pd,
1615 Tensor* reorder_tensor,
1616 OpKernelContext* ctx = nullptr) {
1617 DCHECK(reorder_tensor);
1618 return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor),
1619 *cpu_engine_, ctx);
1620 }
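/// Illustrative sketch (made-up names; not part of the API): reordering an
/// input directly into a previously allocated temporary tensor via the
/// cached-reorder fast path above.
///
///   Tensor tmp_tensor;
///   // ... allocate tmp_tensor with enough space for op_md ...
///   filter.CheckReorderToOpMem(op_md, &tmp_tensor, context);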
1621
1622 /// Function to handle output reorder
1623 ///
1624 /// This function is very similar to the input reordering
1625 /// function above. The only difference is that this function does not add a
1626 /// reorder primitive to the net. The reason for this is: the reorder
1627 /// primitive for output needs to be added to the list only after operation
1628 /// has executed. But we need to prepare a temporary buffer in case output
1629 /// reorder is needed. And this temporary buffer will hold the output of
1630 /// an operation before it is fed to reorder primitive.
1631 ///
1632 /// @input - memory primitive descriptor (memory descriptor for v1.x) for the
1633 /// given output of an operation
1634 /// @return: true in case reorder of output is needed; false, otherwise.
1635 inline bool PrepareReorderToUserMemIfReq(const memory::desc& op_pd) {
1636 DCHECK(user_memory_);
1637 if (IsReorderNeeded(op_pd)) {
1638 // TODO(nhasabni): can we remove dynamic memory allocation?
1639 reorder_memory_ = new memory(op_pd, *cpu_engine_);
1640 return true;
1641 }
1642 return false;
1643 }
1644
1645 /// Function to actually insert reorder primitive in the net
1646 ///
1647 /// This function completes remaining part of output reordering. It inserts
1648 /// a reordering primitive from the temporary buffer that holds the output
1649 /// to the user-specified output buffer.
1650 ///
1651 /// @input: net - net to which to add reorder primitive
1652 /// @input: net_args - net to which user and reorder memories are added if
1653 /// needed. Each entry is a key-value pair of the form
1654 /// <argument-type, mkldnn::memory>.
1655 inline void InsertReorderToUserMem(std::vector<primitive>& net,
1656 std::vector<MemoryArgsMap>& net_args) {
1657 DCHECK(user_memory_);
1658 DCHECK(reorder_memory_);
1659 net.push_back(CreateReorder(reorder_memory_, user_memory_));
1660 net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *reorder_memory_},
1661 {MKLDNN_ARG_TO, *user_memory_}});
1662 }
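/// Illustrative sketch (op_primitive/op_args are made-up placeholders; not
/// part of the API): the typical output-reorder pattern in a kernel.
///
///   bool needs_reorder = dst.PrepareReorderToUserMemIfReq(op_dst_md);
///   net.push_back(op_primitive);           // op writes into dst.GetOpMem()
///   net_args.push_back(op_args);
///   if (needs_reorder) dst.InsertReorderToUserMem(net, net_args);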
1663
1664 /// TODO: This is a faster path, using the reorder primitive cache, than
1665 /// InsertReorderToUserMem(net, net_args). The slow path will be removed
1666 /// in the future.
1667 inline void InsertReorderToUserMem(OpKernelContext* ctx = nullptr) {
1668 DCHECK(user_memory_);
1669 DCHECK(reorder_memory_);
1670 DCHECK(cpu_engine_);
1671 // Primitive reuse does not allow two identical reorder primitives in
1672 // one stream, so submit the reorder immediately.
1673 std::vector<primitive> net;
1674 auto* prim = FindOrCreateReorder<T>(reorder_memory_, user_memory_);
1675 net.push_back(*(prim->GetPrimitive()));
1676 std::vector<MemoryArgsMap> net_args;
1677 net_args.push_back(
1678 {{MKLDNN_ARG_FROM, *reorder_memory_}, {MKLDNN_ARG_TO, *user_memory_}});
1679 std::shared_ptr<stream> cpu_stream;
1680 MklDnnThreadPool eigen_tp;
1681 if (ctx != nullptr) {
1682 eigen_tp = MklDnnThreadPool(ctx);
1683 cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine()));
1684 } else {
1685 cpu_stream.reset(CreateStream(nullptr, prim->GetEngine()));
1686 }
1687 execute_primitives(net, cpu_stream, net_args);
1688 }
1689 };
1690
1691 /// Base class for operations with reuse of primitives
1692 class MklPrimitive {
1693 public:
1694 virtual ~MklPrimitive() {}
1695 MklPrimitive() {}
1696 MklPrimitive(const engine& cpu_engine) { cpu_engine_ = cpu_engine; }
1697 // Dummy data which MKL DNN never operates on
1698 unsigned char* DummyData = nullptr;
1699 engine cpu_engine_ = engine(engine::kind::cpu, 0);
1700 const engine& GetEngine() { return cpu_engine_; }
1701 };
1702
1703 const mkldnn::memory::dims NONE_DIMS = {};
1704
1705 //
1706 // LRUCache is a class which implements an LRU (Least Recently Used) cache.
1707 // The implementation is similar to that of
1708 // tensorflow/core/platform/cloud/expiring_lru_cache.h
1709 // without its thread-safe part because the cache is supposed to be
1710 // used as thread local (for instance, MklPrimitive caching).
1711 //
1712 // The LRU list maintains objects in order of most recent access, with the
1713 // most recently accessed object at the head of the list and the least
1714 // recently accessed object at the tail of the list.
1715 //
1716 // This class is used to maintain an upper bound on the total number of
1717 // cached items. When the cache reaches its capacity, the least recently
1718 // used item is removed and replaced by the new one passed to the SetOp
1719 // call.
1720 //
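// Illustrative sketch (hypothetical key string and primitive pointer; not
// part of the API):
//
//   LRUCache<MklPrimitive> cache(2);
//   cache.SetOp("conv_fwd:f32:1x3x224x224", prim);  // prim is heap-allocated;
//                                                   // the cache takes ownership
//   MklPrimitive* found = cache.GetOp("conv_fwd:f32:1x3x224x224");
//   // found is nullptr if the entry was never inserted or has been evicted.
//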
1721 template <typename T>
1722 class LRUCache {
1723 public:
1724 explicit LRUCache(size_t capacity) {
1725 capacity_ = capacity;
1726 Clear();
1727 }
1728
1729 T* GetOp(const string& key) {
1730 auto it = cache_.find(key);
1731 if (it == cache_.end()) {
1732 return nullptr;
1733 }
1734
1735 // Move to the front of LRU list as the most recently accessed.
1736 lru_list_.erase(it->second.lru_iterator);
1737 lru_list_.push_front(it->first);
1738 it->second.lru_iterator = lru_list_.begin();
1739 return it->second.op;
1740 }
1741
1742 void SetOp(const string& key, T* op) {
1743 if (lru_list_.size() >= capacity_) {
1744 Delete();
1745 }
1746
1747 // Insert an entry to the front of the LRU list
1748 lru_list_.push_front(key);
1749 Entry entry(op, lru_list_.begin());
1750 cache_.emplace(std::make_pair(key, std::move(entry)));
1751 }
1752
1753 void Clear() {
1754 if (lru_list_.empty()) return;
1755
1756 // Clean up the cache
1757 cache_.clear();
1758 lru_list_.clear();
1759 }
1760
1761 private:
1762 struct Entry {
1763 // The entry's value.
1764 T* op;
1765
1766 // A list iterator pointing to the entry's position in the LRU list.
1767 std::list<string>::iterator lru_iterator;
1768
1769 // Constructor
1770 Entry(T* op, std::list<string>::iterator it) {
1771 this->op = op;
1772 this->lru_iterator = it;
1773 }
1774
1775 // Move constructor
1776 Entry(Entry&& source) noexcept
1777 : lru_iterator(std::move(source.lru_iterator)) {
1778 op = std::move(source.op);
1779 source.op = std::forward<T*>(nullptr);
1780 }
1781
1782 // Destructor
1783 ~Entry() {
1784 if (op != nullptr) delete op;
1785 }
1786 };
1787
1788 // Remove the least recently accessed entry from LRU list, which
1789 // is the tail of lru_list_. Update cache_ correspondingly.
1790 bool Delete() {
1791 if (lru_list_.empty()) return false;
1792 string key = lru_list_.back();
1793 lru_list_.pop_back();
1794 cache_.erase(key);
1795 return true;
1796 }
1797
1798 // Cache capacity
1799 size_t capacity_;
1800
1801 // The cache, a map from string key to a LRU entry.
1802 std::unordered_map<string, Entry> cache_;
1803
1804 // The LRU list of entries.
1805 // The front of the list contains the key of the most recently accessed
1806 // entry, while the back of the list is the least recently accessed entry.
1807 std::list<string> lru_list_;
1808 };
1809
1810 template <typename T>
1811 class MklPrimitiveFactory {
1812 public:
1813 MklPrimitiveFactory() {}
1814
1815 ~MklPrimitiveFactory() {}
1816
1817 MklPrimitive* GetOp(const string& key) {
1818 auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1819 return lru_cache.GetOp(key);
1820 }
1821
1822 void SetOp(const string& key, MklPrimitive* op) {
1823 auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1824 lru_cache.SetOp(key, op);
1825 }
1826
1827 /// Function to decide whether the HW has AVX512 or AVX2.
1828 /// For legacy devices (without AVX512 and AVX2),
1829 /// MKL-DNN GEMM will be used.
1830 static inline bool IsLegacyPlatform() {
1831 static const bool is_legacy_platform =
1832 (!port::TestCPUFeature(port::CPUFeature::AVX512F) &&
1833 !port::TestCPUFeature(port::CPUFeature::AVX2));
1834 return is_legacy_platform;
1835 }
1836
1837 /// Function to check whether primitive memory optimization is enabled
1838 static inline bool IsPrimitiveMemOptEnabled() {
1839 static const bool is_primitive_mem_opt_enabled = [] {
1840 bool value = true;
1841 TF_CHECK_OK(
1842 ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true, &value));
1843 return value;
1844 }();
1845 return is_primitive_mem_opt_enabled;
1846 }
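/// For example, this primitive-memory optimization can be disabled at runtime
/// by setting the environment variable read above (it defaults to enabled):
///
///   TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE=false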
1847
1848 private:
1849 static inline LRUCache<MklPrimitive>& GetLRUCache() {
1850 static const int kCapacity = 1024; // cache capacity
1851 static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity);
1852 return lru_cache_;
1853 }
1854 };
1855
1856 // Utility class for creating keys for the MKL primitive pool.
1857 class FactoryKeyCreator {
1858 public:
1859 FactoryKeyCreator() { key_.reserve(kMaxKeyLength); }
1860
1861 ~FactoryKeyCreator() {}
1862
1863 void AddAsKey(const string& str) { Append(str); }
1864
1865 void AddAsKey(const mkldnn::memory::dims& dims) {
1866 for (unsigned int i = 0; i < dims.size(); i++) {
1867 AddAsKey<int>(dims[i]);
1868 }
1869 }
1870
1871 template <typename T>
1872 void AddAsKey(const T data) {
1873 auto buffer = reinterpret_cast<const char*>(&data);
1874 Append(StringPiece(buffer, sizeof(T)));
1875 }
1876
1877 // Generalization to handle pointers.
1878 void AddAsKey(const void* data) {
1879 auto buffer = reinterpret_cast<const char*>(&data);
1880 Append(StringPiece(buffer, sizeof(data)));
1881 }
1882
1883 string GetKey() { return key_; }
1884
1885 private:
1886 string key_;
1887 const char delimiter = 'x';
1888 const int kMaxKeyLength = 256;
1889 void Append(StringPiece s) {
1890 key_.append(string(s));
1891 key_.append(1, delimiter);
1892 }
1893 };
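// Illustrative sketch of FactoryKeyCreator (made-up values; not part of the
// API): building a cache key from a prefix, dimensions, and a scalar.
//
//   FactoryKeyCreator key_creator;
//   key_creator.AddAsKey(string("conv_fwd"));
//   key_creator.AddAsKey(memory::dims({1, 3, 224, 224}));
//   key_creator.AddAsKey<float>(0.5f);  // e.g., a hypothetical alpha parameter
//   string key = key_creator.GetKey();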
1894
1895 class MklReorderPrimitive : public MklPrimitive {
1896 public:
1897 explicit MklReorderPrimitive(const memory* from, const memory* to)
1898 : MklPrimitive(engine(engine::kind::cpu, 0)) {
1899 Setup(from, to);
1900 }
1901 ~MklReorderPrimitive() {}
1902
1903 std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
1904
1905 void SetMemory(const memory* from, const memory* to) {
1906 context_.src_mem->set_data_handle(from->get_data_handle());
1907 context_.dst_mem->set_data_handle(to->get_data_handle());
1908 }
1909
1910 std::shared_ptr<mkldnn::stream> GetStream() { return stream_; }
1911
1912 private:
1913 struct ReorderContext {
1914 std::shared_ptr<mkldnn::memory> src_mem;
1915 std::shared_ptr<mkldnn::memory> dst_mem;
1916 std::shared_ptr<primitive> reorder_prim;
1917 ReorderContext()
1918 : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
1919 } context_;
1920
1921 std::shared_ptr<mkldnn::stream> stream_;
1922
1923 void Setup(const memory* from, const memory* to) {
1924 context_.src_mem.reset(
1925 new memory(from->get_desc(), cpu_engine_, DummyData));
1926 context_.dst_mem.reset(new memory(to->get_desc(), cpu_engine_, DummyData));
1927 context_.reorder_prim = std::make_shared<mkldnn::reorder>(
1928 reorder(*context_.src_mem, *context_.dst_mem));
1929 stream_.reset(new stream(cpu_engine_));
1930 }
1931 };
1932
1933 template <typename T>
1934 class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
1935 public:
1936 static MklReorderPrimitive* Get(const memory* from, const memory* to) {
1937 auto reorderPrim = static_cast<MklReorderPrimitive*>(
1938 MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
1939 if (reorderPrim == nullptr) {
1940 reorderPrim = new MklReorderPrimitive(from, to);
1941 MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to,
1942 reorderPrim);
1943 }
1944 reorderPrim->SetMemory(from, to);
1945 return reorderPrim;
1946 }
1947
1948 static MklReorderPrimitiveFactory& GetInstance() {
1949 static MklReorderPrimitiveFactory instance_;
1950 return instance_;
1951 }
1952
1953 static string CreateKey(const memory* from, const memory* to) {
1954 string prefix = "reorder";
1955 FactoryKeyCreator key_creator;
1956 auto const& from_desc = from->get_desc().data;
1957 auto const& to_desc = to->get_desc().data;
1958 memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
1959 memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
1960 auto from_strides = from_desc.format_desc.blocking.strides;
1961
1962 // The DNNL memory desc stores these fields in C-style arrays and only
1963 // initializes the used portion, so only the valid part is added to the key.
1964 auto from_inner_nblks = from_desc.format_desc.blocking.inner_nblks;
1965 auto from_inner_blks = from_desc.format_desc.blocking.inner_blks;
1966 auto from_inner_idxs = from_desc.format_desc.blocking.inner_idxs;
1967 memory::dims from_inner_blks_1(from_inner_blks,
1968 &from_inner_blks[from_inner_nblks]);
1969 memory::dims from_inner_idxs_1(from_inner_idxs,
1970 &from_inner_idxs[from_inner_nblks]);
1971 auto to_inner_nblks = to_desc.format_desc.blocking.inner_nblks;
1972 auto to_inner_blks = to_desc.format_desc.blocking.inner_blks;
1973 auto to_inner_idxs = to_desc.format_desc.blocking.inner_idxs;
1974 memory::dims to_inner_blks_1(to_inner_blks, &to_inner_blks[to_inner_nblks]);
1975 memory::dims to_inner_idxs_1(to_inner_idxs, &to_inner_idxs[to_inner_nblks]);
1976
1977 auto to_strides = to_desc.format_desc.blocking.strides;
1978 memory::dims from_strides_outer_blocks(from_strides,
1979 &from_strides[from_desc.ndims]);
1980 memory::dims to_strides_outer_blocks(to_strides,
1981 &to_strides[to_desc.ndims]);
1982
1983 key_creator.AddAsKey(prefix);
1984 key_creator.AddAsKey(static_cast<int>(from_desc.extra.flags));
1985 key_creator.AddAsKey(static_cast<int>(from_inner_nblks));
1986 key_creator.AddAsKey(from_inner_blks_1);
1987 key_creator.AddAsKey(from_inner_idxs_1);
1988 key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
1989 key_creator.AddAsKey(from_dims);
1990 key_creator.AddAsKey(from_strides_outer_blocks);
1991 key_creator.AddAsKey(static_cast<int>(to_desc.extra.flags));
1992 key_creator.AddAsKey(static_cast<int>(to_inner_nblks));
1993 key_creator.AddAsKey(to_inner_blks_1);
1994 key_creator.AddAsKey(to_inner_idxs_1);
1995 key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
1996 key_creator.AddAsKey(to_dims);
1997 key_creator.AddAsKey(to_strides_outer_blocks);
1998 return key_creator.GetKey();
1999 }
2000
2001 private:
2002 MklReorderPrimitiveFactory() {}
2003 ~MklReorderPrimitiveFactory() {}
2004
2005 MklPrimitive* GetReorder(const memory* from, const memory* to) {
2006 string key = CreateKey(from, to);
2007 return this->GetOp(key);
2008 }
2009
2010 void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
2011 string key = CreateKey(from, to);
2012 this->SetOp(key, op);
2013 }
2014 };
2015
2016 /// Function to find (or create) a reorder from the memory pointed to by
2017 /// `from` to the memory pointed to by `to`. It creates the primitive, or
2018 /// fetches it from the pool if it is already cached.
2019 /// Returns the primitive.
2020 template <typename T>
2021 inline MklReorderPrimitive* FindOrCreateReorder(const memory* from,
2022 const memory* to) {
2023 CHECK_NOTNULL(from);
2024 CHECK_NOTNULL(to);
2025 MklReorderPrimitive* reorder_prim =
2026 MklReorderPrimitiveFactory<T>::Get(from, to);
2027 return reorder_prim;
2028 }
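// Illustrative sketch (made-up memory objects; not part of the API): running
// a cached reorder obtained from FindOrCreateReorder.
//
//   // src_mem and dst_mem are mkldnn::memory objects with different descs.
//   MklReorderPrimitive* reorder_prim =
//       FindOrCreateReorder<float>(&src_mem, &dst_mem);
//   std::shared_ptr<stream> s = reorder_prim->GetStream();
//   reorder_prim->GetPrimitive()->execute(
//       *s, {{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});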
2029
2030 // Utility function to determine whether a convolution is 1x1 with
2031 // stride != 1, for the purpose of temporarily disabling primitive reuse.
2032 inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
2033 memory::dims strides) {
2034 if (filter_dims.size() != 4 || strides.size() != 2) return false;
2035
2036 return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
2037 ((strides[0] != 1) || (strides[1] != 1)));
2038 }
2039
2040 } // namespace tensorflow
2041
2042 /////////////////////////////////////////////////////////////////////
2043 // Macros for handling registration for various types
2044 /////////////////////////////////////////////////////////////////////
2045
2046 #define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input);
2047
2048 #define REGISTER_TEST_BFLOAT16(TEST) \
2049 REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input);
2050
2051 #define REGISTER_TEST_ALL_TYPES(TEST) \
2052 REGISTER_TEST_FLOAT32(TEST); \
2053 REGISTER_TEST_BFLOAT16(TEST);
2054 #else
2055 #define REGISTER_TEST_ALL_TYPES(TEST) REGISTER_TEST_FLOAT32(TEST);
2056
2057 #endif // INTEL_MKL
2058 #endif // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
2059