1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
17 #define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
18 #ifdef INTEL_MKL
19
20 #include <list>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26
27 #include "mkldnn.hpp"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/graph/mkl_graph_util.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/stringpiece.h"
34 #include "tensorflow/core/lib/gtl/array_slice.h"
35 #include "tensorflow/core/platform/cpu_info.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/util/env_var.h"
39 #include "tensorflow/core/util/mkl_threadpool.h"
40 #include "tensorflow/core/util/mkl_types.h"
41 #include "tensorflow/core/util/padding.h"
42 #include "tensorflow/core/util/tensor_format.h"
43
44 using mkldnn::engine;
45 using mkldnn::memory;
46 using mkldnn::primitive;
47 using mkldnn::reorder;
48 using mkldnn::stream;
49 using CPUDevice = Eigen::ThreadPoolDevice;
50 using MemoryArgsMap = std::unordered_map<int, memory>;
51 using ReorderPd = mkldnn::reorder::primitive_desc;
52
53 #ifdef _WIN32
54 typedef unsigned int uint;
55 #endif
56
57 namespace tensorflow {
58
59 // This file contains a number of utility classes and functions used by
60 // MKL-enabled kernels.
61
62 // This class encapsulates all the meta data that is associated with an MKL
63 // tensor. A tensor is an MKL tensor if it was created as the result of an
64 // MKL operation, and did not go through a conversion to a standard
65 // Tensorflow tensor.
66
67 // The dimension order that MKL-DNN internally uses for 2D activations
68 // is [Batch, Channel, Height, Width] and
69 // for 2D filters it is [Out_Channel, In_Channel, Height, Width].
70 typedef enum {
71 Dim_N = 0,
72 Dim_C = 1,
73 Dim_H = 2,
74 Dim_W = 3,
75 Dim_O = 0,
76 Dim_I = 1
77 } MklDnnDims;
78
79 // The dimension order that MKL-DNN internally uses for 3D activations
80 // is [Batch, Channel, Depth, Height, Width] and
81 // for 3D filters it is [Out_Channel, In_Channel, Depth, Height, Width].
82 typedef enum {
83 Dim3d_N = 0,
84 Dim3d_C = 1,
85 Dim3d_D = 2,
86 Dim3d_H = 3,
87 Dim3d_W = 4,
88 Dim3d_O = 0,
89 Dim3d_I = 1
90 } MklDnnDims3D;
91
92 // Enum for the order of dimensions of a TF 2D filter with shape [filter_height,
93 // filter_width, in_channels, out_channels]
94 typedef enum {
95 TF_2DFILTER_DIM_H = 0,
96 TF_2DFILTER_DIM_W = 1,
97 TF_2DFILTER_DIM_I = 2,
98 TF_2DFILTER_DIM_O = 3
99 } TFFilterDims2d;
100
101 // Enum for the order of dimensions of a TF 3D filter with shape [filter_depth,
102 // filter_height, filter_width, in_channels, out_channels]
103 typedef enum {
104 TF_3DFILTER_DIM_P = 0,
105 TF_3DFILTER_DIM_H = 1,
106 TF_3DFILTER_DIM_W = 2,
107 TF_3DFILTER_DIM_I = 3,
108 TF_3DFILTER_DIM_O = 4
109 } TFFilterDims3d;
110
111 // The dimensions order that MKL-DNN requires for the filter in a grouped
112 // convolution (2D only)
113 typedef enum {
114 MKL_GROUP_FILTER_DIM_G = 0,
115 MKL_GROUP_FILTER_DIM_O = 1,
116 MKL_GROUP_FILTER_DIM_I = 2,
117 MKL_GROUP_FILTER_DIM_H = 3,
118 MKL_GROUP_FILTER_DIM_W = 4
119 } MklDnnFilterGroupDims;
120
121 // Enum used to templatize MklOp kernel implementations
122 // that support both fp32 and int8 versions.
123 enum class MklQuantization {
124 QUANTIZED_VERSION,
125 FP_VERSION,
126 };
127
128 static const int kSmallBatchSize = 32;
129
130 inline void execute_primitives(
131 std::vector<mkldnn::primitive>& primitives, std::shared_ptr<stream> stream,
132 std::vector<std::unordered_map<int, memory>>& net_args) {
133 DCHECK_EQ(primitives.size(), net_args.size());
134 for (size_t i = 0; i < primitives.size(); ++i) {
135 primitives.at(i).execute(*stream, net_args.at(i));
136 }
137 }
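// Illustrative usage sketch (added for clarity; not from the original source):
// a caller builds matching `primitives` and `net_args` vectors and runs them
// on a single stream. `ctx`, `cpu_engine`, `src_mem`, and `dst_mem` below are
// assumed to be provided by the caller.
//
//   std::vector<mkldnn::primitive> net;
//   std::vector<std::unordered_map<int, memory>> net_args;
//   net.push_back(mkldnn::reorder(src_mem, dst_mem));
//   net_args.push_back({{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
//   std::shared_ptr<stream> cpu_stream(CreateStream(ctx, cpu_engine));
//   execute_primitives(net, cpu_stream, net_args);
//   cpu_stream->wait();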
138
139 // In MKL-DNN v1.x, the format (ex. NCHW) used to initialize a memory descriptor
140 // (md) structure will no longer be recorded in its `format` field. Instead, it
141 // will be set to a canonical `blocked` format for every fully described md.
142 //
143 // Currently, we query this `format` field while mapping MKL-DNN's data format
144 // to TF's data format. Due to the above restriction, we will now get this data
145 // format information from TF's `data_format` attribute (i.e. via
146 // `TensorFormat`) for MKL-DNN v1.x.
147 //
148 // Some MKL-DNN operators such as ReLU do not have a `data_format` attribute
149 // since they are usually in `blocked` format. Therefore, in order to
150 // distinguish between blocked and non-blocked formats, we have defined a new
151 // enum called `MklTensorFormat` that is semantically similar to `TensorFormat`
152 // but with the following additional enum values:
153 // 1) FORMAT_BLOCKED: as described above, this is needed for element-wise
154 // operators such as ReLU.
155 // 2) FORMAT_INVALID: for error-checking (ex. unsupported format)
156 // 3) FORMAT_X, FORMAT_NC, FORMAT_TNC: to distinguish between MKL tensors based
157 // on their dimensions in operators such as Softmax, i.e.:
158 // FORMAT_X - 1D tensor
159 // FORMAT_NC - 2D tensor
160 // FORMAT_TNC - 3D tensor
161 enum class MklTensorFormat {
162 FORMAT_NHWC = 0,
163 FORMAT_NCHW = 1,
164 FORMAT_NDHWC = 2,
165 FORMAT_NCDHW = 3,
166 FORMAT_X = 4,
167 FORMAT_NC = 5,
168 FORMAT_TNC = 6,
169 FORMAT_BLOCKED = 7,
170 FORMAT_INVALID = 8,
171 };
172
173 // Forward declarations
174 memory::format_tag MklTensorFormatToMklDnnDataFormat(MklTensorFormat format);
175
176 TensorFormat MklDnn3DDataFormatToTFDataFormat(MklTensorFormat format);
177 TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format);
178
179 memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
180 Status CreateBlockedMemDescHelper(const memory::dims& dim,
181 const memory::dims& strides,
182 memory::data_type dtype,
183 mkldnn_memory_desc_t* blocked_md);
184
185 inline std::ostream& operator<<(std::ostream& os,
186 const memory::format_tag& tag) {
187 if (tag == memory::format_tag::undef) {
188 os << "undef";
189 } else if (tag == memory::format_tag::any) {
190 os << "any";
191 } else {
192 os << "invalid";
193 }
194 return os;
195 }
196
197 inline void operator<<(std::ostream& os, const MklTensorFormat& format) {
198 if (format == MklTensorFormat::FORMAT_NHWC) {
199 os << "FORMAT_NHWC";
200 } else if (format == MklTensorFormat::FORMAT_NCHW) {
201 os << "FORMAT_NCHW";
202 } else if (format == MklTensorFormat::FORMAT_NDHWC) {
203 os << "FORMAT_NDHWC";
204 } else if (format == MklTensorFormat::FORMAT_NCDHW) {
205 os << "FORMAT_NCDHW";
206 } else if (format == MklTensorFormat::FORMAT_X) {
207 os << "FORMAT_X";
208 } else if (format == MklTensorFormat::FORMAT_NC) {
209 os << "FORMAT_NC";
210 } else if (format == MklTensorFormat::FORMAT_TNC) {
211 os << "FORMAT_TNC";
212 } else if (format == MklTensorFormat::FORMAT_BLOCKED) {
213 os << "FORMAT_BLOCKED";
214 } else {
215 os << "INVALID FORMAT";
216 }
217 }
218
219 template <typename T>
220 inline bool array_cmp(const T* a1, const T* a2, size_t size) {
221 for (size_t i = 0; i < size; ++i)
222 if (a1[i] != a2[i]) return false;
223 return true;
224 }
225
226 inline mkldnn::stream* CreateStream(OpKernelContext* ctx,
227 const engine& engine) {
228 #ifdef ENABLE_MKLDNN_THREADPOOL
229 stream_attr tp_stream_attr(engine::kind::cpu);
230 if (ctx != nullptr) {
231 auto eigen_tp =
232 MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx);
233 tp_stream_attr.set_threadpool(eigen_tp);
234 stream* tp_stream =
235 new stream(engine, stream::flags::default_flags, tp_stream_attr);
236 return tp_stream;
237 } else {
238 stream* tp_stream = new stream(engine);
239 return tp_stream;
240 }
241 #else
242 stream* tp_stream = new stream(engine);
243 return tp_stream;
244 #endif // ENABLE_MKLDNN_THREADPOOL
245 }
246
247 class MklDnnShape {
248 private:
249 struct MklShapeData {
250 // Flag to indicate if the tensor is an MKL tensor or not
251 bool is_mkl_tensor_ = false;
252 // Number of dimensions in Tensorflow format
253 size_t dimension_ = 0;
254 mkldnn_dims_t sizes_; // Required by MKL for conversions
255 MklTensorFormat tf_data_format_ = MklTensorFormat::FORMAT_BLOCKED;
256 memory::data_type T_ = memory::data_type::undef;
257 // MKL layout
258 mkldnn_memory_desc_t mkl_md_;
259 /// TF dimension corresponding to this MKL dimension
260 mkldnn_dims_t map_;
261 };
262 MklShapeData data_;
263
264 typedef std::remove_extent<mkldnn_dims_t>::type mkldnn_dim_t;
265
266 #define INVALID_DIM_SIZE -1
267
268 public:
269 MklDnnShape() {
270 for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
271 ++i) {
272 data_.sizes_[i] = -1;
273 }
274 for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) {
275 data_.map_[i] = -1;
276 }
277 }
278
279 ~MklDnnShape() {}
280 TF_DISALLOW_COPY_AND_ASSIGN(MklDnnShape); // Cannot copy
281
282 /// Equality function for MklDnnShape objects
283 /// @return true if both are equal; false otherwise.
284 inline bool operator==(const MklDnnShape& input_shape) const {
285 if (this->IsMklTensor() != input_shape.IsMklTensor()) {
286 return false;
287 }
288
289 // If input tensors are in MKL layout, then we check for dimensions and
290 // sizes.
291 if (this->IsMklTensor()) {
292 const mkldnn_memory_desc_t& cur_md = (this->GetMklLayout()).data;
293 const mkldnn_memory_desc_t& input_shape_md =
294 input_shape.GetMklLayout().data;
295 return this->GetTfShape() == input_shape.GetTfShape() &&
296 mkldnn_memory_desc_equal(&cur_md, &input_shape_md);
297 }
298
299 // Both inputs are not MKL tensors.
300 return true;
301 }
302
303 /// Equality operator for MklDnnShape and TFShape.
304 /// Returns: true if TF shapes for both are the same, false otherwise
305 inline bool operator==(const TensorShape& input_shape) const {
306 if (!this->IsMklTensor()) {
307 return false;
308 }
309
310 return this->GetTfShape() == input_shape;
311 }
312
313 inline const bool IsMklTensor() const { return data_.is_mkl_tensor_; }
314 inline void SetMklTensor(bool is_mkl_tensor) {
315 data_.is_mkl_tensor_ = is_mkl_tensor;
316 }
317
318 inline void SetDimensions(const size_t dimension) {
319 data_.dimension_ = dimension;
320 }
321 inline size_t GetDimension(char dimension) const {
322 int index = GetMklDnnTensorDimIndex(dimension);
323 CHECK(index >= 0 && index < this->GetDimension())
324 << "Invalid index from the dimension: " << index << ", " << dimension;
325 return this->DimSize(index);
326 }
327
328 inline size_t GetDimension3D(char dimension) const {
329 int index = GetMklDnnTensor3DDimIndex(dimension);
330 CHECK(index >= 0 && index < this->GetDimension())
331 << "Invalid index from the dimension: " << index << ", " << dimension;
332 return this->DimSize(index);
333 }
334
335 inline int32 GetMklDnnTensorDimIndex(char dimension) const {
336 switch (dimension) {
337 case 'N':
338 return MklDnnDims::Dim_N;
339 case 'C':
340 return MklDnnDims::Dim_C;
341 case 'H':
342 return MklDnnDims::Dim_H;
343 case 'W':
344 return MklDnnDims::Dim_W;
345 default:
346 LOG(FATAL) << "Invalid dimension: " << dimension;
347 return -1; // Avoid compiler warning about missing return value
348 }
349 }
350
351 inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
352 switch (dimension) {
353 case 'N':
354 return MklDnnDims3D::Dim3d_N;
355 case 'C':
356 return MklDnnDims3D::Dim3d_C;
357 case 'D':
358 return MklDnnDims3D::Dim3d_D;
359 case 'H':
360 return MklDnnDims3D::Dim3d_H;
361 case 'W':
362 return MklDnnDims3D::Dim3d_W;
363 default:
364 LOG(FATAL) << "Invalid dimension: " << dimension;
365 return -1; // Avoid compiler warning about missing return value
366 }
367 }
368
369 inline size_t GetDimension() const { return data_.dimension_; }
370 inline const int* GetSizes() const {
371 return reinterpret_cast<const int*>(&data_.sizes_[0]);
372 }
373
374 // Returns an mkldnn::memory::dims object that contains the sizes of this
375 // MklDnnShape object.
376 inline memory::dims GetSizesAsMklDnnDims() const {
377 memory::dims retVal;
378 if (data_.is_mkl_tensor_) {
379 size_t dimensions = sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
380 for (size_t i = 0; i < dimensions; i++) {
381 if (data_.sizes_[i] != INVALID_DIM_SIZE)
382 retVal.push_back(data_.sizes_[i]);
383 }
384 } else {
385 CHECK_EQ(data_.is_mkl_tensor_, true);
386 }
387 return retVal;
388 }
389
390 inline int64 DimSize(int index) const {
391 CHECK_LT(index, sizeof(data_.sizes_) / sizeof(data_.sizes_[0]));
392 return data_.sizes_[index];
393 }
394
395 /// Return TensorShape that describes the Tensorflow shape of the tensor
396 /// represented by this MklShape.
397 inline TensorShape GetTfShape() const {
398 CHECK_EQ(data_.is_mkl_tensor_, true);
399
400 std::vector<int32> shape(data_.dimension_, -1);
401 // As mentioned in the comment above, we now rely on TF's `data_format`
402 // attribute to determine if TF shape is in blocked format or not.
403 if (data_.tf_data_format_ != MklTensorFormat::FORMAT_BLOCKED) {
404 for (size_t idx = 0; idx < data_.dimension_; ++idx) {
405 shape[idx] = data_.sizes_[TfDimIdx(idx)];
406 }
407 } else {
408 // If Tensorflow shape is in Blocked format, then we don't have dimension
409 // map for it. So we just create Tensorflow shape from sizes in the
410 // specified order.
411 for (size_t idx = 0; idx < data_.dimension_; ++idx) {
412 shape[idx] = data_.sizes_[idx];
413 }
414 }
415
416 TensorShape ts;
417 bool ret = TensorShapeUtils::MakeShape(shape, &ts).ok();
418 CHECK_EQ(ret, true);
419 return ts;
420 }
421
422 inline void SetElemType(memory::data_type dt) { data_.T_ = dt; }
423 inline const memory::data_type GetElemType() { return data_.T_; }
424
425 inline void SetMklLayout(memory::desc* md) {
426 CHECK_NOTNULL(md);
427 data_.mkl_md_ = md->data;
428 }
429
430 inline const memory::desc GetMklLayout() const {
431 return memory::desc(data_.mkl_md_);
432 }
433
434 inline MklTensorFormat GetTfDataFormat() const {
435 return data_.tf_data_format_;
436 }
437
438 /// We don't create primitive_descriptor for TensorFlow layout now.
439 /// We use lazy evaluation and create it only when needed. Input format can
440 /// also be Blocked format.
441 inline void SetTfLayout(size_t dims, const memory::dims& sizes,
442 MklTensorFormat format) {
443 DCHECK_EQ(dims, sizes.size())
444 << "SetTfLayout: Number of dimensions does not"
445 "match with dimension array";
446 data_.dimension_ = dims;
447 for (size_t ii = 0; ii < dims; ++ii) {
448 data_.sizes_[ii] = sizes[ii];
449 }
450 data_.tf_data_format_ = format;
451 if (format != MklTensorFormat::FORMAT_BLOCKED) {
452 if (dims == 2) {
453 data_.map_[0] = MklDnnDims::Dim_N;
454 data_.map_[1] = MklDnnDims::Dim_C;
455 } else {
456 SetTfDimOrder(dims, format);
457 }
458 }
459 }
460
461 inline const memory::desc GetTfLayout() const {
462 memory::dims dims;
463 for (size_t ii = 0; ii < data_.dimension_; ++ii) {
464 dims.push_back(data_.sizes_[ii]);
465 }
466
467 // Create Blocked memory desc if input TF format was set like that.
468 if (data_.tf_data_format_ == MklTensorFormat::FORMAT_BLOCKED) {
469 auto strides = CalculateTFStrides(dims);
470 mkldnn_memory_desc_t blocked_md;
471 TF_CHECK_OK(
472 CreateBlockedMemDescHelper(dims, strides, data_.T_, &blocked_md));
473 return memory::desc(blocked_md);
474 } else {
475 auto format_tag =
476 MklTensorFormatToMklDnnDataFormat(data_.tf_data_format_);
477 DCHECK_NE(format_tag, memory::format_tag::undef);
478 return memory::desc(dims, data_.T_, format_tag);
479 }
480 }
481
482 inline const memory::desc GetCurLayout() const {
483 return IsMklTensor() ? GetMklLayout() : GetTfLayout();
484 }
485
486 // We don't need a case for the default dimension order because
487 // when an operator that has no data_format attribute gets all of its inputs
488 // in Tensorflow format, it will produce output in Tensorflow format.
489 inline void SetTfDimOrder(const size_t dimension, const mkldnn_dims_t map) {
490 CHECK(dimension == data_.dimension_);
491 for (size_t ii = 0; ii < dimension; ii++) {
492 data_.map_[ii] = map[ii];
493 }
494 }
495
496 inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
497 if (dimension == 5) {
498 CHECK(dimension == data_.dimension_);
499 data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
500 MklDnnDims3D::Dim3d_D;
501 data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
502 MklDnnDims3D::Dim3d_H;
503 data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
504 MklDnnDims3D::Dim3d_W;
505 data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
506 MklDnnDims3D::Dim3d_C;
507 data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
508 MklDnnDims3D::Dim3d_N;
509 } else {
510 CHECK_EQ(dimension, 4);
511 CHECK(dimension == data_.dimension_);
512 data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
513 data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
514 data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
515 data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
516 }
517 }
518
519 inline void SetTfDimOrder(const size_t dimension, MklTensorFormat format) {
520 TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
521 SetTfDimOrder(dimension, data_format);
522 }
523
524 inline const mkldnn_dim_t* GetTfToMklDimMap() const { return &data_.map_[0]; }
525 inline size_t TfDimIdx(int index) const { return data_.map_[index]; }
526 inline int64 TfDimSize(int index) const {
527 return data_.sizes_[TfDimIdx(index)];
528 }
529
530 /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
531 /// corresponds to MKL's Channel dimension.
532 inline bool IsMklChannelDim(int d) const {
533 return TfDimIdx(d) == MklDnnDims::Dim_C;
534 }
535
536 /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
537 /// corresponds to MKL's Batch dimension.
538 inline bool IsMklBatchDim(int d) const {
539 return TfDimIdx(d) == MklDnnDims::Dim_N;
540 }
541
542 /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
543 /// corresponds to MKL's Width dimension.
544 inline bool IsMklWidthDim(int d) const {
545 return TfDimIdx(d) == MklDnnDims::Dim_W;
546 }
547 /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
548 /// corresponds to MKL's Height dimension.
549 inline bool IsMklHeightDim(int d) const {
550 return TfDimIdx(d) == MklDnnDims::Dim_H;
551 }
552
553 /// Check if the TF-MKL dimension ordering map specifies if the input
554 /// tensor is in NCHW format.
555 inline bool IsTensorInNCHWFormat() const {
556 TensorFormat data_format = FORMAT_NCHW;
557 return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
558 IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
559 IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
560 IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
561 }
562
563 /// Check if the TF-MKL dimension ordering map specifies if the input
564 /// tensor is in NHWC format.
565 inline bool IsTensorInNHWCFormat() const {
566 TensorFormat data_format = FORMAT_NHWC;
567 return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
568 IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
569 IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
570 IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
571 }
572
573 /// The following methods are used for serializing and de-serializing the
574 /// contents of the mklshape object.
575 /// The data is serialized in this order:
576 /// is_mkl_tensor_ : dimension_ : sizes_ : tf_data_format_ : T_ : mkl_md_ : map_;
577
578 /// Size of buffer to hold the serialized object, the size is computed by
579 /// following above mentioned order
580 inline size_t GetSerializeBufferSize() const { return sizeof(MklShapeData); }
581
582 void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const {
583 CHECK(buf_size >= GetSerializeBufferSize())
584 << "Buffer size is too small to SerializeMklDnnShape";
585 *reinterpret_cast<MklShapeData*>(buf) = data_;
586 }
587
588 void DeSerializeMklDnnShape(const unsigned char* buf, size_t buf_size) {
589 // Make sure buffer holds at least is_mkl_tensor_.
590 CHECK(buf_size >= sizeof(data_.is_mkl_tensor_))
591 << "Buffer size is too small in DeSerializeMklDnnShape";
592
593 const bool is_mkl_tensor = *reinterpret_cast<const bool*>(buf);
594 if (is_mkl_tensor) { // If it is an MKL Tensor then read the rest
595 CHECK(buf_size >= GetSerializeBufferSize())
596 << "Buffer size is too small in DeSerializeMklDnnShape";
597 data_ = *reinterpret_cast<const MklShapeData*>(buf);
598 }
599 }
600 };
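// Illustrative sketch (an assumption, not from the original source) of how an
// MKL kernel typically fills an MklDnnShape for its output; `output_md` is a
// memory::desc chosen by MKL-DNN and `output_dims_mkl_order` holds the output
// sizes in MKL-DNN (NCHW) order, both owned by the surrounding kernel:
//
//   MklDnnShape output_mkl_shape;
//   output_mkl_shape.SetMklTensor(true);
//   output_mkl_shape.SetElemType(MklDnnType<float>());
//   output_mkl_shape.SetMklLayout(&output_md);
//   output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
//                                output_dims_mkl_order,
//                                MklTensorFormat::FORMAT_NCHW);
//
// The shape is then serialized into the kernel's metadata output via
// AllocateOutputSetMklShape() declared below.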
601
602 // List of MklShape objects. Used in Concat/Split layers.
603 typedef std::vector<MklDnnShape> MklDnnShapeList;
604
605 template <typename T>
606 class MklDnnData;
607
608 // TODO: Merge with execute_primitives() defined above.
609 inline void ExecutePrimitive(const std::vector<primitive>& net,
610 const std::vector<MemoryArgsMap>* net_args,
611 const engine& cpu_engine,
612 OpKernelContext* context = nullptr) {
613 DCHECK(net_args);
614 DCHECK_EQ(net.size(), net_args->size());
615 stream* cpu_stream = CreateStream(context, cpu_engine);
616 for (size_t i = 0; i < net.size(); ++i) {
617 net.at(i).execute(*cpu_stream, net_args->at(i));
618 }
619 cpu_stream->wait();
620 delete cpu_stream;
621 }
622 template <typename T>
623 inline Status ConvertMklToTF(OpKernelContext* context,
624 const Tensor& input_mkl_tensor,
625 const MklDnnShape& input_mkl_shape,
626 Tensor* output_tf_tensor) {
627 try {
628 if (!input_mkl_shape.IsMklTensor()) {
629 // Return input as is since it is already a TF tensor
630 *output_tf_tensor = input_mkl_tensor;
631 return Status::OK();
632 }
633
634 // Allocate output tensor.
635 TensorShape output_tf_shape = input_mkl_shape.GetTfShape();
636 TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<T>::v(), output_tf_shape,
637 output_tf_tensor));
638
639 engine cpu_engine(engine::kind::cpu, 0);
640 MklDnnData<T> input(&cpu_engine);
641
642 // Get MKL layout of input tensor.
643 auto input_mkl_md = input_mkl_shape.GetMklLayout();
644 auto output_tf_md = input_mkl_shape.GetTfLayout();
645 input.SetUsrMem(input_mkl_md, &input_mkl_tensor);
646
647 if (input.IsReorderNeeded(output_tf_md)) {
648 std::vector<primitive> net;
649 std::vector<MemoryArgsMap> net_args;
650 bool status = input.CheckReorderToOpMem(output_tf_md, output_tf_tensor,
651 net, net_args, cpu_engine);
652 if (!status) {
653 return Status(error::Code::INTERNAL,
654 "ConvertMklToTF(): Failed to create reorder for input");
655 }
656 ExecutePrimitive(net, &net_args, cpu_engine, context);
657 } else {
658 // If not, just forward input tensor to output tensor.
659 bool status =
660 output_tf_tensor->CopyFrom(input_mkl_tensor, output_tf_shape);
661 if (!status) {
662 return Status(
663 error::Code::INTERNAL,
664 "ConvertMklToTF(): Failed to forward input tensor to output");
665 }
666 }
667 return Status::OK();
668 } catch (mkldnn::error& e) {
669 string error_msg = "Status: " + std::to_string(e.status) +
670 ", message: " + string(e.message) + ", in file " +
671 string(__FILE__) + ":" + std::to_string(__LINE__);
672 LOG(FATAL) << "Operation received an exception: " << error_msg;
673 }
674 }
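// Illustrative usage sketch (assumed, not from the original source): convert a
// possibly-MKL-laid-out input back into a plain TF tensor inside a kernel
// helper that returns Status:
//
//   MklDnnShape mkl_shape;
//   GetMklShape(context, 0, &mkl_shape);
//   Tensor tf_tensor;
//   TF_RETURN_IF_ERROR(ConvertMklToTF<float>(
//       context, MklGetInput(context, 0), mkl_shape, &tf_tensor));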
675
676 // Get the MKL shape from the second (serialized metadata) tensor
677 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape,
678 bool eager_mode) {
679 if (!eager_mode) {
680 mklshape->DeSerializeMklDnnShape(
681 ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
682 .flat<uint8>()
683 .data(),
684 ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
685 .flat<uint8>()
686 .size() *
687 sizeof(uint8));
688 } else {
689 mklshape->SetMklTensor(false);
690 }
691 }
692
693 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
694 GetMklShape(ctext, n, mklshape, false);
695 }
696
697 // Gets the actual input
698 inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
699 return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
700 }
701
702 inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
703 OpInputList* input_tensors) {
704 CHECK_NOTNULL(input_tensors);
705 TF_CHECK_OK(ctext->input_list(name, input_tensors));
706 }
707
708 inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
709 MklDnnShapeList* mkl_shapes) {
710 OpInputList input_mkl_tensors;
711 GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
712
713 for (int i = 0; i < input_mkl_tensors.size(); i++) {
714 (*mkl_shapes)[i].DeSerializeMklDnnShape(
715 input_mkl_tensors[i].flat<uint8>().data(),
716 input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
717 }
718 }
719
720 /// Get shape of input tensor pointed by 'input_idx' in TensorShape format.
721 /// If the input tensor is in MKL layout, then obtains TensorShape from
722 /// MklShape.
723 inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx,
724 bool eager_mode = false) {
725 // Sanity check.
726 CHECK_NOTNULL(context);
727 CHECK_LT(input_idx, context->num_inputs());
728
729 MklDnnShape input_mkl_shape;
730 GetMklShape(context, input_idx, &input_mkl_shape, eager_mode);
731 if (input_mkl_shape.IsMklTensor() && !eager_mode) {
732 return input_mkl_shape.GetTfShape();
733 } else {
734 const Tensor& t = MklGetInput(context, input_idx);
735 return t.shape();
736 }
737 }
738
739 // Allocate the second output tensor that will contain
740 // the MKL shape serialized
741 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
742 const MklDnnShape& mkl_shape) {
743 Tensor* second_tensor = nullptr;
744 TensorShape second_shape;
745 second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
746 OP_REQUIRES_OK(ctext, ctext->allocate_output(
747 GetTensorMetaDataIndex(n, ctext->num_outputs()),
748 second_shape, &second_tensor));
749 mkl_shape.SerializeMklDnnShape(
750 second_tensor->flat<uint8>().data(),
751 second_tensor->flat<uint8>().size() * sizeof(uint8));
752 }
753
754 // Allocate the output tensor, create a second output tensor that will contain
755 // the MKL shape serialized
756 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
757 Tensor** output,
758 const TensorShape& tf_shape,
759 const MklDnnShape& mkl_shape,
760 bool eager_mode = false) {
761 OP_REQUIRES_OK(
762 ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
763 tf_shape, output));
764 if (!eager_mode) {
765 Tensor* second_tensor = nullptr;
766 TensorShape second_shape;
767 second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
768 OP_REQUIRES_OK(ctext, ctext->allocate_output(
769 GetTensorMetaDataIndex(n, ctext->num_outputs()),
770 second_shape, &second_tensor));
771 mkl_shape.SerializeMklDnnShape(
772 second_tensor->flat<uint8>().data(),
773 second_tensor->flat<uint8>().size() * sizeof(uint8));
774 }
775 }
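// Illustrative sketch (assumed): allocating a TF-format output together with
// its metadata tensor; `tf_out_shape` is the TensorShape computed by the
// kernel:
//
//   MklDnnShape out_mkl_shape;
//   out_mkl_shape.SetMklTensor(false);
//   Tensor* out_tensor = nullptr;
//   AllocateOutputSetMklShape(context, 0, &out_tensor, tf_out_shape,
//                             out_mkl_shape);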
776
777 // Allocates a temp tensor and returns the data buffer for temporary storage.
778 template <typename T>
779 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
780 const memory::desc& pd, void** buf_out) {
781 TensorShape tf_shape;
782
783 tf_shape.AddDim(pd.get_size() / sizeof(T) + 1);
784 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
785 tf_shape, tensor_out));
786 *buf_out = static_cast<void*>(tensor_out->flat<T>().data());
787 }
788
789 template <typename T>
790 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
791 TensorShape tf_shape) {
792 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
793 tf_shape, tensor_out));
794 }
795
796 inline void GetStridesFromSizes(MklTensorFormat data_format, size_t* strides,
797 const size_t* sizes) {
798 DCHECK_NE(data_format, MklTensorFormat::FORMAT_INVALID);
799 // MKL requires strides in NCHW
800 if (data_format == MklTensorFormat::FORMAT_NHWC) {
801 strides[0] = sizes[2];
802 strides[1] = sizes[0] * sizes[2];
803 strides[2] = 1;
804 strides[3] = sizes[0] * sizes[1] * sizes[2];
805 } else {
806 strides[0] = 1;
807 strides[1] = sizes[0];
808 strides[2] = sizes[0] * sizes[1];
809 strides[3] = sizes[0] * sizes[1] * sizes[2];
810 }
811 }
812
813 inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
814 int idx_out) {
815 int num_inputs = context->num_inputs();
816 int num_outputs = context->num_outputs();
817 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
818 int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
819 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
820 int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
821
822 const Tensor& data = context->input(idx_data_in);
823 const Tensor& meta = context->input(idx_meta_in);
824 Tensor output(data.dtype());
825 Tensor meta_output(meta.dtype());
826
827 // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
828 CHECK(output.CopyFrom(data, data.shape()));
829 CHECK(meta_output.CopyFrom(meta, meta.shape()));
830 context->set_output(idx_data_out, output);
831 context->set_output(idx_meta_out, meta_output);
832 }
833
834 inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
835 int idx_out,
836 const TensorShape& shape) {
837 int num_inputs = context->num_inputs();
838 int num_outputs = context->num_outputs();
839 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
840 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
841
842 const Tensor& data = context->input(idx_data_in);
843 MklDnnShape mkl_shape_output;
844 mkl_shape_output.SetMklTensor(false);
845 AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
846 Tensor output(data.dtype());
847 // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
848 CHECK(output.CopyFrom(data, shape));
849 context->set_output(idx_data_out, output);
850 }
851
852 inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
853 int idx_out) {
854 int num_inputs = context->num_inputs();
855 int num_outputs = context->num_outputs();
856 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
857 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
858
859 MklDnnShape dnn_shape_output;
860 dnn_shape_output.SetMklTensor(false);
861 AllocateOutputSetMklShape(context, idx_out, dnn_shape_output);
862 if (IsRefType(context->input_dtype(idx_data_in))) {
863 context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
864 } else {
865 context->set_output(idx_data_out, context->input(idx_data_in));
866 }
867 }
868
869 inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in,
870 int idx_out) {
871 int num_inputs = context->num_inputs();
872 int num_outputs = context->num_outputs();
873 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
874 int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
875 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
876 int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
877
878 if (IsRefType(context->input_dtype(idx_data_in))) {
879 context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
880 context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
881 } else {
882 context->set_output(idx_data_out, context->input(idx_data_in));
883 context->set_output(idx_meta_out, context->input(idx_meta_in));
884 }
885 }
886
887 // Set a dummy MKLDNN shape (called when the output is in TF format)
888 inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
889 uint32 idx_data_out) {
890 MklDnnShape mkl_shape_output;
891 mkl_shape_output.SetMklTensor(false);
892 AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
893 }
894
895 // If the input tensor has ref count as 1, it is forwarded to the desired
896 // output port and the function returns true. In that case, it also allocates
897 // the serialized MklDnnShape object. Otherwise, the function returns false.
898 inline bool ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
899 int idx_in, int idx_out,
900 Tensor** output,
901 const MklDnnShape& mkl_shape,
902 bool always_forward = true) {
903 int num_inputs = context->num_inputs();
904 int num_outputs = context->num_outputs();
905 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
906 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
907 bool is_forwarded = false;
908 const Tensor& input_tensor = context->input(idx_data_in);
909 const auto output_shape = input_tensor.shape();
910 if (always_forward) {
911 if (IsRefType(context->input_dtype(idx_data_in))) {
912 context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
913 } else {
914 context->set_output(idx_data_out, input_tensor);
915 }
916 } else {
917 is_forwarded = context->forward_input_to_output_with_shape(
918 idx_data_in, idx_data_out, output_shape, output);
919 }
920 if (is_forwarded || always_forward) {
921 AllocateOutputSetMklShape(context, idx_out, mkl_shape);
922 return true;
923 }
924 return false;
925 }
926
927 // Forward the MKL shape ONLY (used in elementwise and other ops where
928 // we call the eigen implementation and MKL shape is not used)
929 inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
930 uint32 idx_data_in,
931 uint32_t idx_data_out) {
932 uint32 idx_meta_in =
933 GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
934 uint32 idx_meta_out =
935 GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
936
937 if (IsRefType(context->input_dtype(idx_data_in))) {
938 context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
939 } else {
940 context->set_output(idx_meta_out, context->input(idx_meta_in));
941 }
942 }
943
944 // -------------------------------------------------------------------
945 // Common utility functions used by MKL unit tests
946
947 inline Tensor GetMklMetaTensor() {
948 MklDnnShape non_mkl_shape;
949 non_mkl_shape.SetMklTensor(false);
950
951 auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize());
952 Tensor tensor(DT_UINT8, {size});
953
954 non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
955 size * sizeof(uint8));
956 return tensor;
957 }
958
959 // -------------------------------------------------------------------
960
961 /// Return MKL-DNN data type (memory::data_type) for input type T
962 ///
963 /// @input None
964 /// @return memory::data_type corresponding to type T
965 template <typename T>
966 static memory::data_type MklDnnType();
967
968 /// Instantiation for float type. Add similar instantiations for other
969 /// types if needed.
970 template <>
971 memory::data_type MklDnnType<float>() {
972 return memory::data_type::f32;
973 }
974
975 template <>
976 memory::data_type MklDnnType<quint8>() {
977 return memory::data_type::u8;
978 }
979
980 template <>
981 memory::data_type MklDnnType<uint8>() {
982 return memory::data_type::u8;
983 }
984
985 template <>
986 memory::data_type MklDnnType<qint8>() {
987 return memory::data_type::s8;
988 }
989
990 template <>
991 memory::data_type MklDnnType<qint32>() {
992 return memory::data_type::s32;
993 }
994 template <>
995 memory::data_type MklDnnType<bfloat16>() {
996 #ifdef ENABLE_INTEL_MKL_BFLOAT16
997 return memory::data_type::bf16;
998 #else
999 return memory::data_type::f32;
1000 #endif
1001 }
1002
1003 // Map MklTensorFormat to MKL-DNN format tag
1004 //
1005 // @input: MklTensorFormat i.e. TensorFlow data format
1006 // @return: MKL-DNN's memory format tag corresponding to MklTensorFormat.
1007 // Fails with an error if invalid data format.
1008 inline memory::format_tag MklTensorFormatToMklDnnDataFormat(
1009 MklTensorFormat format) {
1010 if (format == MklTensorFormat::FORMAT_NHWC) return memory::format_tag::nhwc;
1011 if (format == MklTensorFormat::FORMAT_NCHW) return memory::format_tag::nchw;
1012 if (format == MklTensorFormat::FORMAT_NDHWC) return memory::format_tag::ndhwc;
1013 if (format == MklTensorFormat::FORMAT_NCDHW) return memory::format_tag::ncdhw;
1014 if (format == MklTensorFormat::FORMAT_X) return memory::format_tag::x;
1015 if (format == MklTensorFormat::FORMAT_NC) return memory::format_tag::nc;
1016 if (format == MklTensorFormat::FORMAT_TNC) return memory::format_tag::tnc;
1017 return memory::format_tag::undef;
1018 }
1019
1020 /// Map TensorFlow data format into MKL-DNN 3D data format
1021 /// @input: TensorFlow data format
1022 /// @return: MKL-DNN 3D data format corresponding to TensorFlow data format;
1023 /// Fails with an error if invalid data format.
1024 inline MklTensorFormat TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
1025 if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NDHWC;
1026 if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCDHW;
1027 TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1028 return MklTensorFormat::FORMAT_INVALID;
1029 }
1030
1031 /// Map TensorFlow data format into MKL-DNN data format
1032 ///
1033 /// @input: TensorFlow data format
1034 /// @return: MKL-DNN data format corresponding to TensorFlow data format;
1035 /// Fails with an error if invalid data format.
1036 inline MklTensorFormat TFDataFormatToMklDnnDataFormat(TensorFormat format) {
1037 if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NHWC;
1038 if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCHW;
1039 TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1040 return MklTensorFormat::FORMAT_INVALID;
1041 }
1042
1043 /// Map MKL-DNN data format into TensorFlow data format
1044 ///
1045 /// @input: MKL-DNN data format
1046 /// @return: Tensorflow data format corresponding to MKL-DNN data format;
1047 /// Fails with an error if invalid data format.
1048 inline TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format) {
1049 if (format == MklTensorFormat::FORMAT_NHWC ||
1050 format == MklTensorFormat::FORMAT_NDHWC)
1051 return FORMAT_NHWC;
1052 if (format == MklTensorFormat::FORMAT_NCHW ||
1053 format == MklTensorFormat::FORMAT_NCDHW)
1054 return FORMAT_NCHW;
1055 TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1056
1057 // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure
1058 // that we don't come here.
1059 return FORMAT_NHWC;
1060 }
1061
1062 /// Map TensorShape object into memory::dims required by MKL-DNN
1063 ///
1064 /// This function will simply map input TensorShape into MKL-DNN dims
1065 /// naively. So it will preserve the order of dimensions. E.g., if
1066 /// input tensor is in NHWC format, then dims will be in NHWC format also.
1067 ///
1068 /// @input TensorShape object in shape
1069 /// @return memory::dims corresponding to TensorShape
1070 inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
1071 memory::dims dims(shape.dims());
1072 for (int d = 0; d < shape.dims(); ++d) {
1073 dims[d] = shape.dim_size(d);
1074 }
1075 return dims;
1076 }
1077
1078 /// Map TensorShape object into memory::dims in NCHW format required by MKL-DNN
1079 ///
1080 /// This function is more specific than the one above. It will map the input
1081 /// TensorShape into MKL-DNN dims in NCHW format. So it may not preserve the
1082 /// order of dimensions. E.g., if input tensor is in NHWC format, then dims
1083 /// will be in NCHW format, and not in NHWC format.
1084 ///
1085 /// @input TensorShape object in shape
1086 /// @return memory::dims in MKL-DNN required NCHW format
1087 inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
1088 TensorFormat format) {
1089 // Check validity of format.
1090 DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1091 MklTensorFormat::FORMAT_INVALID);
1092
1093 int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
1094 int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
1095 int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
1096 int w = shape.dim_size(GetTensorDimIndex(format, 'W'));
1097
1098 // MKL-DNN requires dimensions in NCHW format.
1099 return memory::dims({n, c, h, w});
1100 }
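// For example (illustrative only): an NHWC input of shape {8, 224, 224, 3}
// yields dims in NCHW order:
//
//   memory::dims dims =
//       TFShapeToMklDnnDimsInNCHW(TensorShape({8, 224, 224, 3}), FORMAT_NHWC);
//   // dims == {8, 3, 224, 224}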
1101
1102 inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
1103 TensorFormat format) {
1104 // Validate format.
1105 DCHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
1106 MklTensorFormat::FORMAT_INVALID);
1107
1108 int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
1109 int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
1110 int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
1111 int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
1112 int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
1113
1114 // MKL-DNN requires dimensions in NCDHW format.
1115 return memory::dims({n, c, d, h, w});
1116 }
1117
1118 /// Overloaded version of function TFShapeToMklDnnDimsInNCHW above.
1119 /// Input parameters are self-explanatory.
1120 inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
1121 TensorFormat format) {
1122 // Validate format.
1123 DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1124 MklTensorFormat::FORMAT_INVALID);
1125
1126 int n = in_dims[GetTensorDimIndex(format, 'N')];
1127 int c = in_dims[GetTensorDimIndex(format, 'C')];
1128 int h = in_dims[GetTensorDimIndex(format, 'H')];
1129 int w = in_dims[GetTensorDimIndex(format, 'W')];
1130
1131 // MKL-DNN requires dimensions in NCHW format.
1132 return memory::dims({n, c, h, w});
1133 }
1134
1135 /// Overloaded version of function TFShapeToMklDnnDimsInNCDHW above.
1136 /// Input parameters are self-explanatory.
1137 inline memory::dims MklDnnDimsInNCDHW(const memory::dims& in_dims,
1138 TensorFormat format) {
1139 // Validate format.
1140 DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1141 MklTensorFormat::FORMAT_INVALID);
1142
1143 int n = in_dims[GetTensorDimIndex<3>(format, 'N')];
1144 int c = in_dims[GetTensorDimIndex<3>(format, 'C')];
1145 int d = in_dims[GetTensorDimIndex<3>(format, '0')];
1146 int h = in_dims[GetTensorDimIndex<3>(format, '1')];
1147 int w = in_dims[GetTensorDimIndex<3>(format, '2')];
1148
1149 // MKL DNN requires dimensions in NCDHW format.
1150 return memory::dims({n, c, d, h, w});
1151 }
1152
1153 /// Map MklDnn memory::dims object into TensorShape object.
1154 ///
1155 /// This function will simply map input shape in MKL-DNN memory::dims format
1156 /// into Tensorflow's TensorShape object by preserving dimension order.
1157 ///
1158 /// @input MKL-DNN memory::dims object
1159 /// @output TensorShape corresponding to memory::dims
1160 inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) {
1161 std::vector<int32> shape(dims.size(), -1);
1162 for (int d = 0; d < dims.size(); d++) {
1163 shape[d] = dims[d];
1164 }
1165
1166 TensorShape ret;
1167 CHECK_EQ(TensorShapeUtils::MakeShape(shape, &ret).ok(), true);
1168 return ret;
1169 }
1170
1171 /// Function to calculate strides given tensor shape in Tensorflow order
1172 /// E.g., if dims_tf_order is {1, 2, 3, 4}, then as per Tensorflow convention,
1173 /// dimension with size 1 is outermost dimension; while dimension with size 4 is
1174 /// innermost dimension. So strides for this tensor would be {4 * 3 * 2,
1175 /// 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}.
1176 ///
1177 /// @input Tensorflow shape in memory::dims type
1178 /// @return memory::dims containing strides for the tensor.
1179 inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) {
1180 CHECK_GT(dims_tf_order.size(), 0);
1181 memory::dims strides(dims_tf_order.size());
1182 int last_dim_idx = dims_tf_order.size() - 1;
1183 strides[last_dim_idx] = 1;
1184 for (int d = last_dim_idx - 1; d >= 0; d--) {
1185 strides[d] = strides[d + 1] * dims_tf_order[d + 1];
1186 }
1187 return strides;
1188 }
1189
1190 /// Helper function to create memory descriptor in Blocked format
1191 ///
1192 /// @input: Tensor dimensions
1193 /// @input: strides corresponding to dimensions. One can use utility
1194 /// function such as CalculateTFStrides to compute strides
1195 /// for given dimensions.
1196 /// @output: mkldnn_memory_desc_t object corresponding to blocked memory
1197 /// format for given dimensions and strides.
1198 /// @return: Status indicating whether the blocked memory descriptor
1199 /// was successfully created.
1200 inline Status CreateBlockedMemDescHelper(const memory::dims& dim,
1201 const memory::dims& strides,
1202 memory::data_type dtype,
1203 mkldnn_memory_desc_t* blocked_md) {
1204 DCHECK_EQ(dim.size(), strides.size());
1205 const int kNumDims = dim.size();
1206 mkldnn_dim_t* input_dims = new mkldnn_dim_t[kNumDims];
1207 mkldnn_dim_t* input_strides = new mkldnn_dim_t[kNumDims];
1208 for (int i = 0; i < kNumDims; ++i) {
1209 input_dims[i] = dim[i];
1210 input_strides[i] = strides[i];
1211 }
1212 try {
1213 mkldnn_memory_desc_init_by_strides(blocked_md, kNumDims, input_dims,
1214 memory::convert_to_c(dtype),
1215 input_strides);
1216 delete[] input_dims;
1217 delete[] input_strides;
1218 } catch (mkldnn::error& e) {
1219 delete[] input_dims;
1220 delete[] input_strides;
1221 return Status(error::Code::INTERNAL,
1222 tensorflow::strings::StrCat(
1223 "Failed to create blocked memory descriptor.",
1224 "Status: ", e.status, ", message: ", e.message));
1225 }
1226 return Status::OK();
1227 }
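// Illustrative sketch (assumed): creating a blocked memory descriptor for a
// 6D tensor, whose rank has no native MKL-DNN format tag, using the strides
// helper above:
//
//   memory::dims dims = {2, 3, 4, 5, 6, 7};
//   memory::dims strides = CalculateTFStrides(dims);  // {2520, 840, 210, 42, 7, 1}
//   mkldnn_memory_desc_t blocked_md;
//   TF_CHECK_OK(CreateBlockedMemDescHelper(dims, strides,
//                                          memory::data_type::f32, &blocked_md));
//   memory::desc md(blocked_md);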
1228
1229 inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
1230 const memory& src_mem,
1231 const memory& dst_mem, const engine& engine,
1232 OpKernelContext* ctx = nullptr) {
1233 std::vector<primitive> net;
1234 net.push_back(mkldnn::reorder(reorder_desc));
1235 std::vector<MemoryArgsMap> net_args;
1236 net_args.push_back({{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
1237 ExecutePrimitive(net, &net_args, engine, ctx);
1238 }
1239
1240 class MklReorderPrimitive;
1241
1242 template <typename T>
1243 inline MklReorderPrimitive* FindOrCreateReorder(const memory* from,
1244 const memory* to);
1245
1246 // Class to represent all the resources corresponding to a tensor in TensorFlow
1247 // that are required to execute an operation (such as Convolution).
1248 template <typename T>
1249 class MklDnnData {
1250 private:
1251 /// MKL-DNN memory primitive for input user memory
1252 memory* user_memory_;
1253
1254 /// MKL-DNN memory primitive in case input or output reorder is needed.
1255 memory* reorder_memory_;
1256
1257 /// Operations memory descriptor
1258 memory::desc* op_md_;
1259 /// Flag to indicate if data is 3D or not.
1260 bool bIs3D;
1261 /// Operations temp buffer
1262 void* allocated_buffer_;
1263 /// CPU engine on which operation will be executed
1264 const engine* cpu_engine_;
1265
1266 public:
1267 explicit MklDnnData(const engine* e)
1268 : user_memory_(nullptr),
1269 reorder_memory_(nullptr),
1270 op_md_(nullptr),
1271 bIs3D(false),
1272 allocated_buffer_(nullptr),
1273 cpu_engine_(e) {}
1274
1275 ~MklDnnData() {
1276 if (allocated_buffer_ != nullptr) {
1277 cpu_allocator()->DeallocateRaw(allocated_buffer_);
1278 }
1279 cpu_engine_ = nullptr; // We don't own this.
1280 delete (user_memory_);
1281 delete (reorder_memory_);
1282 delete (op_md_);
1283 }
1284
1285 inline void* GetTensorBuffer(const Tensor* tensor) const {
1286 CHECK_NOTNULL(tensor);
1287 return const_cast<void*>(
1288 static_cast<const void*>(tensor->flat<T>().data()));
1289 }
1290
1291 void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
1292 bool GetIs3D() { return bIs3D; }
1293
1294 /// Set user memory primitive using specified dimensions, memory format tag
1295 /// and data_buffer. The element data type is derived automatically from the
1296 /// template type T used to create this object.
1297 ///
1298 /// In a nutshell, this function allows the user to describe an input tensor
1299 /// of an operation. E.g., the filter of Conv2D has shape {1, 2, 3, 4} and
1300 /// memory format tag HWIO, and the buffer that contains the actual values is
1301 /// pointed to by data_buffer.
1302 inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm,
1303 void* data_buffer = nullptr) {
1304 auto md = memory::desc(dim, MklDnnType<T>(), fm);
1305 SetUsrMem(md, data_buffer);
1306 }
1307
1308 inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm,
1309 const Tensor* tensor) {
1310 DCHECK(tensor);
1311 SetUsrMem(dim, fm, GetTensorBuffer(tensor));
1312 }
1313
1314 /// Helper function to create memory descriptor in Blocked format
1315 ///
1316 /// @input: Tensor dimensions
1317 /// @input: strides corresponding to dimensions. One can use utility
1318 /// function such as CalculateTFStrides to compute strides
1319 /// for given dimensions.
1320 /// @return: memory::desc object corresponding to blocked memory format
1321 /// for given dimensions and strides.
1322 static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim,
1323 const memory::dims& strides) {
1324 mkldnn_memory_desc_t blocked_md;
1325 TF_CHECK_OK(
1326 CreateBlockedMemDescHelper(dim, strides, MklDnnType<T>(), &blocked_md));
1327 return memory::desc(blocked_md);
1328 }
1329
1330 /// A version of SetUsrMem that allows the user to create memory in blocked
1331 /// format. In addition to accepting dimensions, it also accepts strides.
1332 /// This allows the user to create memory for a tensor in a format that is not
1333 /// natively supported by MKL-DNN. E.g., MKL-DNN has no native format for a
1334 /// 6-dimensional tensor, but by using blocked format a user can create
1335 /// memory for a 6D tensor.
1336 inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
1337 void* data_buffer = nullptr) {
1338 CHECK_EQ(dim.size(), strides.size());
1339 auto blocked_md = MklDnnData<T>::CreateBlockedMemDesc(dim, strides);
1340 SetUsrMem(blocked_md, data_buffer);
1341 }
1342
1343 inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
1344 const Tensor* tensor) {
1345 CHECK_NOTNULL(tensor);
1346 SetUsrMem(dim, strides, GetTensorBuffer(tensor));
1347 }
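// Illustrative sketch (assumed): describing a 6D user tensor through blocked
// format; `tensor` is a 6D TensorFlow Tensor and `cpu_engine` an MKL-DNN CPU
// engine owned by the caller:
//
//   MklDnnData<float> input(&cpu_engine);
//   memory::dims dims_6d = TFShapeToMklDnnDims(tensor.shape());
//   input.SetUsrMem(dims_6d, CalculateTFStrides(dims_6d), &tensor);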
1348
1349 /// A version of SetUsrMem with memory descriptor and tensor
1350 inline void SetUsrMem(const memory::desc& md, const Tensor* tensor) {
1351 CHECK_NOTNULL(tensor);
1352 SetUsrMem(md, GetTensorBuffer(tensor));
1353 }
1354
1355 /// A version of function to set user memory type that accepts memory
1356 /// descriptor directly, instead of accepting dimensions and format. This
1357 /// function is more generic than the one above, but the function above is
1358 /// sufficient in most cases.
1359 inline void SetUsrMem(const memory::desc& pd, void* data_buffer = nullptr) {
1360 DCHECK(cpu_engine_);
1361 if (user_memory_) delete user_memory_;
1362 // TODO(nhasabni): can we remove dynamic memory allocation?
1363 if (data_buffer) {
1364 user_memory_ = new memory(pd, *cpu_engine_, data_buffer);
1365 } else {
1366 user_memory_ = new memory(pd, *cpu_engine_);
1367 }
1368 }
1369
1370 /// Get function for user memory primitive.
1371 inline const memory* GetUsrMem() const { return user_memory_; }
1372
1373 /// Get function for descriptor of user memory.
1374 inline memory::desc GetUsrMemDesc() const {
1375 DCHECK(user_memory_);
1376 return user_memory_->get_desc();
1377 }
1378
1379 /// Get function for data buffer of user memory primitive.
1380 inline void* GetUsrMemDataHandle() const {
1381 CHECK_NOTNULL(user_memory_);
1382 return user_memory_->get_data_handle();
1383 }
1384
1385 /// Set function for data buffer of user memory primitive.
1386 inline void SetUsrMemDataHandle(void* data_buffer,
1387 std::shared_ptr<stream> t_stream = nullptr) {
1388 CHECK_NOTNULL(user_memory_);
1389 CHECK_NOTNULL(data_buffer);
1390 #ifdef ENABLE_MKLDNN_THREADPOOL
1391 user_memory_->set_data_handle(data_buffer, *t_stream);
1392 #else
1393 user_memory_->set_data_handle(data_buffer);
1394 #endif // ENABLE_MKLDNN_THREADPOOL
1395 }
1396
1397 /// Set function for data buffer of user memory primitive.
1398 inline void SetUsrMemDataHandle(const Tensor* tensor,
1399 std::shared_ptr<stream> t_stream = nullptr) {
1400 SetUsrMemDataHandle(GetTensorBuffer(tensor), t_stream);
1401 }
1402
1403 /// allocate function for data buffer
1404 inline void AllocateBuffer(size_t size) {
1405 const int64 kMemoryAlignment = 64; // For AVX512 memory alignment.
1406 allocated_buffer_ = cpu_allocator()->AllocateRaw(kMemoryAlignment, size);
1407 }
1408
1409 inline void* GetAllocatedBuffer() { return allocated_buffer_; }
1410
1411 /// Get the memory primitive for input and output of an op. If inputs
1412 /// to an op require reorders, then this function returns memory primitive
1413 /// for reorder. Otherwise, it will return memory primitive for user memory.
1414 ///
1415 /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
1416 /// execute Conv2D, we need memory primitive for I and F. But if reorder is
1417 /// required for I and F (say I_r is reorder primitive for I; F_r is reorder
1418 /// primitive for F), then we need I_r and F_r to perform Conv2D.
1419   inline const memory& GetOpMem() const {
1420 return reorder_memory_ ? *reorder_memory_ : *user_memory_;
1421 }
1422
1423   /// Set memory descriptor of an operation in terms of dimensions and memory
1424   /// format. E.g., for Conv2D, the dimensions would be the same as the user
1425   /// dimensions, but the format would be memory::format_tag::any because we
1426   /// want MKL-DNN to choose the best layout/format for given input dimensions.
1427   inline void SetOpMemDesc(const memory::dims& dim, memory::format_tag fm) {
1428 // TODO(nhasabni): can we remove dynamic memory allocation?
1429 op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
1430 }
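  /// Example (illustrative sketch): letting MKL-DNN choose the layout for a
  /// Conv2D input. `src` (an MklDnnData<float>) and the NCHW-ordered `src_dims`
  /// are assumed names for illustration.
  ///
  ///   memory::dims src_dims = {batch, channels, height, width};
  ///   src.SetOpMemDesc(src_dims, memory::format_tag::any);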
1431
1432 /// Get function for memory descriptor for an operation
1433   inline const memory::desc& GetOpMemDesc() const { return *op_md_; }
1434
1435   /// Predicate that checks if we need to reorder user's memory into memory
1436   /// pointed to by op_pd.
1437   ///
1438   /// @input: op_pd - memory descriptor of the given input of an operation.
1439 /// @return: true in case reorder of input is needed; false, otherwise.
1440   inline bool IsReorderNeeded(const memory::desc& op_pd) const {
1441 DCHECK(user_memory_);
1442 return op_pd != user_memory_->get_desc();
1443 }
1444
1445   /// Function to create a reorder from the memory pointed to by `from` to the
1446   /// memory pointed to by `to`. Returns the created reorder primitive.
1447   inline primitive CreateReorder(const memory* from, const memory* to) const {
1448 CHECK_NOTNULL(from);
1449 CHECK_NOTNULL(to);
1450 return reorder(*from, *to);
1451 }
1452
1453 /// Function to handle input reordering
1454 ///
1455 /// Check if we need to reorder this input of an operation.
1456 /// Return true and allocate reorder memory primitive if reorder is needed.
1457 /// Otherwise, return false and do not allocate reorder memory primitive.
1458 ///
1459 /// To check if reorder is needed, this function compares memory primitive
1460   /// descriptor (memory descriptor for v1.x) of an operation (op_md) for
1461   /// the given input with the user-specified memory descriptor.
1462   ///
1463   /// @input: op_md - memory primitive descriptor of the given input of an
1464   ///                 operation
1465 /// @input: net - net to which to add reorder primitive in case it is needed.
1466 /// @input: net_args - net to which user and reorder memories are added if
1467 /// needed. Each entry is a key-value pair of the form
1468 /// <argument-type, mkldnn::memory>.
1469 /// @return: true in case reorder of input is needed; false, otherwise.
1470   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1471 std::vector<primitive>& net,
1472 std::vector<MemoryArgsMap>& net_args,
1473 const engine& engine) {
1474 DCHECK(user_memory_);
1475 DCHECK_EQ(net.size(), net_args.size());
1476 if (IsReorderNeeded(op_md)) {
1477 // TODO(nhasabni): can we remove dynamic memory allocation?
1478 reorder_memory_ = new memory(op_md, engine);
1479 net.push_back(CreateReorder(user_memory_, reorder_memory_));
1480 net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *user_memory_},
1481 {MKLDNN_ARG_TO, *reorder_memory_}});
1482 return true;
1483 }
1484 return false;
1485 }
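  /// Example (illustrative sketch): reordering a user-format input to the
  /// layout chosen by a convolution primitive descriptor, then using GetOpMem()
  /// as the primitive argument. The names `src`, `src_md`, `src_tensor`,
  /// `conv_fwd_pd`, `conv_args`, and `cpu_engine` are assumptions for
  /// illustration only.
  ///
  ///   std::vector<primitive> net;
  ///   std::vector<MemoryArgsMap> net_args;
  ///   src.SetUsrMem(src_md, &src_tensor);
  ///   src.CheckReorderToOpMem(conv_fwd_pd.src_desc(), net, net_args,
  ///                           cpu_engine);
  ///   conv_args.insert({MKLDNN_ARG_SRC, src.GetOpMem()});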
1486
1487 inline bool CheckReorderToOpMem(const memory::desc& op_md,
1488 const engine& engine,
1489 OpKernelContext* context = nullptr) {
1490 DCHECK(user_memory_);
1491 if (IsReorderNeeded(op_md)) {
1492 // TODO(nhasabni): can we remove dynamic memory allocation?
1493       // Primitive reuse does not allow two identical reorder primitives in
1494       // one stream, so submit it immediately.
1495 reorder_memory_ = new memory(op_md, engine);
1496 auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
1497 std::shared_ptr<stream> cpu_stream;
1498 cpu_stream.reset(CreateStream(context, prim->GetEngine()));
1499 std::vector<primitive> net;
1500 net.push_back(*(prim->GetPrimitive()));
1501 std::vector<MemoryArgsMap> net_args;
1502 net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
1503 {MKLDNN_ARG_TO, *reorder_memory_}});
1504 execute_primitives(net, cpu_stream, net_args);
1505 return true;
1506 }
1507 return false;
1508 }
1509
1510 /// Overloaded version of above function that accepts memory buffer
1511 /// where output of reorder needs to be stored.
1512 ///
1513   /// @input: op_md - memory primitive descriptor (memory descriptor for v1.x)
1514   ///                 of the given input of an operation
1515 /// @reorder_data_handle - memory buffer where output of reorder needs to be
1516 /// stored. Primitive does not check if buffer has
1517 /// enough size to write.
1518 /// @input: net - net to which to add reorder primitive in case it is needed.
1519 /// @input: net_args - net to which user and reorder memories are added if
1520 /// needed. Each entry is a key-value pair of the form
1521 /// <argument-type, mkldnn::memory>.
1522 /// @input: engine - MKL-DNN's abstraction of a computational device
1523 /// @return: true in case reorder of input is needed; false, otherwise.
1524   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1525 void* reorder_data_handle,
1526 std::vector<primitive>& net,
1527 std::vector<MemoryArgsMap>& net_args,
1528 const engine& engine) {
1529 DCHECK(reorder_data_handle);
1530 DCHECK(user_memory_);
1531 if (IsReorderNeeded(op_md)) {
1532 // TODO(nhasabni): can we remove dynamic memory allocation?
1533 reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
1534 net.push_back(CreateReorder(user_memory_, reorder_memory_));
1535 net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *user_memory_},
1536 {MKLDNN_ARG_TO, *reorder_memory_}});
1537 return true;
1538 }
1539 return false;
1540 }
1541
1542   /// This is a faster path, using the reorder primitive cache, compared with
1543   /// CheckReorderToOpMem(..., std::vector<primitive>* net).
1544   /// The slower path will be removed in the future.
1545   /// TODO(bhavanis): Need to use reorder cache here for better performance.
1546 inline bool CheckReorderToOpMem(const memory::desc& op_md,
1547 void* reorder_data_handle,
1548 const engine& engine,
1549 OpKernelContext* context = nullptr) {
1550 DCHECK(reorder_data_handle);
1551 DCHECK(user_memory_);
1552 if (IsReorderNeeded(op_md)) {
1553 // TODO(nhasabni): can we remove dynamic memory allocation?
1554       // Primitive reuse does not allow two identical reorder primitives in
1555       // one stream, so submit it immediately.
1556 reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
1557 auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
1558 std::shared_ptr<stream> cpu_stream;
1559 cpu_stream.reset(CreateStream(context, prim->GetEngine()));
1560 std::vector<primitive> net;
1561 net.push_back(*(prim->GetPrimitive()));
1562 std::vector<MemoryArgsMap> net_args;
1563 net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
1564 {MKLDNN_ARG_TO, *reorder_memory_}});
1565 execute_primitives(net, cpu_stream, net_args);
1566 return true;
1567 }
1568 return false;
1569 }
1570
1571 /// Another overloaded version of CheckReorderToOpMem that accepts Tensor
1572 /// where output of reorder needs to be stored.
1573 ///
1574 /// @input: op_md - memory primitive descriptor (memory descriptor for v1.x)
1575 /// of the given input of an operation
1576   /// @reorder_tensor - Tensor whose buffer is to be used to store output of
1577   ///                   reorder. Primitive does not check if the buffer has
1578   ///                   enough size to write.
1579 /// @input: net - net to which to add reorder primitive in case it is needed.
1580 /// @input: net_args - net to which user and reorder memories are added if
1581 /// needed. Each entry is a key-value pair of the form
1582 /// <argument-type, mkldnn::memory>.
1583 /// @input: engine - MKL-DNN's abstraction of a computational device
1584 /// @return: true in case reorder of input is needed; false, otherwise.
1585   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1586 Tensor* reorder_tensor,
1587 std::vector<primitive>& net,
1588 std::vector<MemoryArgsMap>& net_args,
1589 const engine& engine) {
1590 DCHECK(reorder_tensor);
1591 return CheckReorderToOpMem(op_md, GetTensorBuffer(reorder_tensor), net,
1592 net_args, engine);
1593 }
1594
1595   /// TODO: this is a faster path, using the reorder primitive cache,
1596   /// compared with
1597   /// CheckReorderToOpMem(op_md, reorder_tensor, net, net_args, engine);
1598   /// the slow path will be removed in the future.
1599 inline bool CheckReorderToOpMem(const memory::desc& op_pd,
1600 Tensor* reorder_tensor,
1601 OpKernelContext* ctx = nullptr) {
1602 DCHECK(reorder_tensor);
1603 return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor),
1604 *cpu_engine_, ctx);
1605 }
1606
1607 /// Function to handle output reorder
1608 ///
1609   /// This function performs functionality very similar to the input reordering
1610   /// function above. The only difference is that this function does not add
1611 /// reorder primitive to the net. The reason for this is: the reorder
1612 /// primitive for output needs to be added to the list only after operation
1613 /// has executed. But we need to prepare a temporary buffer in case output
1614 /// reorder is needed. And this temporary buffer will hold the output of
1615 /// an operation before it is fed to reorder primitive.
1616 ///
1617 /// @input - memory primitive descriptor (memory descriptor for v1.x) for the
1618 /// given output of an operation
1619 /// @return: true in case reorder of output is needed; false, otherwise.
1620   inline bool PrepareReorderToUserMemIfReq(const memory::desc& op_pd) {
1621 DCHECK(user_memory_);
1622 if (IsReorderNeeded(op_pd)) {
1623 // TODO(nhasabni): can we remove dynamic memory allocation?
1624 reorder_memory_ = new memory(op_pd, *cpu_engine_);
1625 return true;
1626 }
1627 return false;
1628 }
1629
1630 /// Function to actually insert reorder primitive in the net
1631 ///
1632 /// This function completes remaining part of output reordering. It inserts
1633 /// a reordering primitive from the temporary buffer that holds the output
1634 /// to the user-specified output buffer.
1635 ///
1636 /// @input: net - net to which to add reorder primitive
1637 /// @input: net_args - net to which user and reorder memories are added if
1638 /// needed. Each entry is a key-value pair of the form
1639 /// <argument-type, mkldnn::memory>.
1640   inline void InsertReorderToUserMem(std::vector<primitive>& net,
1641 std::vector<MemoryArgsMap>& net_args) {
1642 DCHECK(user_memory_);
1643 DCHECK(reorder_memory_);
1644 net.push_back(CreateReorder(reorder_memory_, user_memory_));
1645 net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *reorder_memory_},
1646 {MKLDNN_ARG_TO, *user_memory_}});
1647 }
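  /// Example (illustrative sketch): the two-step output reorder. `dst` (an
  /// MklDnnData<float> wrapping the user output buffer), `dst_md`,
  /// `output_tensor`, `prim_pd`, `net`, and `net_args` are assumed names.
  ///
  ///   dst.SetUsrMem(dst_md, &output_tensor);
  ///   bool reorder_back = dst.PrepareReorderToUserMemIfReq(prim_pd.dst_desc());
  ///   // ... add the op to `net` with dst.GetOpMem() as MKLDNN_ARG_DST ...
  ///   if (reorder_back) dst.InsertReorderToUserMem(net, net_args);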
1648
1649   /// TODO: this is a faster path, using the reorder primitive cache, compared
1650   /// with InsertReorderToUserMem(net, net_args); the slow path will be
1651   /// removed in the future.
1652 inline void InsertReorderToUserMem(OpKernelContext* ctx = nullptr) {
1653 DCHECK(user_memory_);
1654 DCHECK(reorder_memory_);
1655 DCHECK(cpu_engine_);
1656     // Primitive reuse does not allow two identical reorder primitives in
1657     // one stream, so submit it immediately.
1658 std::vector<primitive> net;
1659 auto* prim = FindOrCreateReorder<T>(reorder_memory_, user_memory_);
1660 net.push_back(*(prim->GetPrimitive()));
1661 std::vector<MemoryArgsMap> net_args;
1662 net_args.push_back(
1663 {{MKLDNN_ARG_FROM, *reorder_memory_}, {MKLDNN_ARG_TO, *user_memory_}});
1664 std::shared_ptr<stream> cpu_stream;
1665 cpu_stream.reset(CreateStream(ctx, prim->GetEngine()));
1666 execute_primitives(net, cpu_stream, net_args);
1667 }
1668 };
1669
1670 /// Base class for operations with reuse of primitives
1671 class MklPrimitive {
1672 public:
1673   virtual ~MklPrimitive() {}
1674   MklPrimitive() {}
1675   MklPrimitive(const engine& cpu_engine) { cpu_engine_ = cpu_engine; }
1676 // Dummy data which MKL DNN never operates on
1677 unsigned char* DummyData = nullptr;
1678 engine cpu_engine_ = engine(engine::kind::cpu, 0);
1679   const engine& GetEngine() { return cpu_engine_; }
1680 };
1681
1682 const mkldnn::memory::dims NONE_DIMS = {};
1683
1684 //
1685 // LRUCache is a class which implements an LRU (Least Recently Used) cache.
1686 // The implementation is similar to that of
1687 // tensorflow/core/platform/cloud/expiring_lru_cache.h
1688 // without its thread-safety, because the cache is meant to be
1689 // used as thread-local (for instance, for MklPrimitive caching).
1690 //
1691 // The LRU list maintains objects in order of most recent access:
1692 // the least recently accessed object is at the tail of the list,
1693 // while the most recently accessed object
1694 // is at the head of the list.
1695 //
1696 // This class is used to maintain an upper bound on the total number of
1697 // cached items. When the cache reaches its capacity, the least recently
1698 // used item is evicted to make room for the new one added by SetOp.
1699 //
1700 template <typename T>
1701 class LRUCache {
1702 public:
1703   explicit LRUCache(size_t capacity) {
1704 capacity_ = capacity;
1705 Clear();
1706 }
1707
1708   T* GetOp(const string& key) {
1709 auto it = cache_.find(key);
1710 if (it == cache_.end()) {
1711 return nullptr;
1712 }
1713
1714 // Move to the front of LRU list as the most recently accessed.
1715 lru_list_.erase(it->second.lru_iterator);
1716 lru_list_.push_front(it->first);
1717 it->second.lru_iterator = lru_list_.begin();
1718 return it->second.op;
1719 }
1720
1721   void SetOp(const string& key, T* op) {
1722 if (lru_list_.size() >= capacity_) {
1723 Delete();
1724 }
1725
1726 // Insert an entry to the front of the LRU list
1727 lru_list_.push_front(key);
1728 Entry entry(op, lru_list_.begin());
1729 cache_.emplace(std::make_pair(key, std::move(entry)));
1730 }
1731
1732   void Clear() {
1733 if (lru_list_.empty()) return;
1734
1735 // Clean up the cache
1736 cache_.clear();
1737 lru_list_.clear();
1738 }
1739
1740 private:
1741 struct Entry {
1742 // The entry's value.
1743 T* op;
1744
1745 // A list iterator pointing to the entry's position in the LRU list.
1746 std::list<string>::iterator lru_iterator;
1747
1748 // Constructor
1749     Entry(T* op, std::list<string>::iterator it) {
1750 this->op = op;
1751 this->lru_iterator = it;
1752 }
1753
1754 // Move constructor
1755     Entry(Entry&& source) noexcept
1756 : lru_iterator(std::move(source.lru_iterator)) {
1757 op = std::move(source.op);
1758 source.op = std::forward<T*>(nullptr);
1759 }
1760
1761 // Destructor
1762     ~Entry() {
1763 if (op != nullptr) delete op;
1764 }
1765 };
1766
1767 // Remove the least recently accessed entry from LRU list, which
1768 // is the tail of lru_list_. Update cache_ correspondingly.
1769   bool Delete() {
1770 if (lru_list_.empty()) return false;
1771 string key = lru_list_.back();
1772 lru_list_.pop_back();
1773 cache_.erase(key);
1774 return true;
1775 }
1776
1777 // Cache capacity
1778 size_t capacity_;
1779
1780 // The cache, a map from string key to a LRU entry.
1781 std::unordered_map<string, Entry> cache_;
1782
1783 // The LRU list of entries.
1784 // The front of the list contains the key of the most recently accessed
1785 // entry, while the back of the list is the least recently accessed entry.
1786 std::list<string> lru_list_;
1787 };
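// Example (illustrative sketch): basic LRUCache usage with a hypothetical
// cached type `MyPrimitive`. The cache owns the pointers it stores and deletes
// them when entries are evicted or the cache is destroyed.
//
//   LRUCache<MyPrimitive> cache(8);
//   cache.SetOp("conv_fwd_3x3", new MyPrimitive());
//   MyPrimitive* hit = cache.GetOp("conv_fwd_3x3");   // non-null; now MRU
//   MyPrimitive* miss = cache.GetOp("unknown_key");   // nullptr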
1788
1789 template <typename T>
1790 class MklPrimitiveFactory {
1791 public:
1792   MklPrimitiveFactory() {}
1793
1794   ~MklPrimitiveFactory() {}
1795
1796   MklPrimitive* GetOp(const string& key) {
1797 auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1798 return lru_cache.GetOp(key);
1799 }
1800
1801   void SetOp(const string& key, MklPrimitive* op) {
1802 auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1803 lru_cache.SetOp(key, op);
1804 }
1805
1806   /// Function to decide whether the HW has AVX512 or AVX2.
1807   /// For legacy devices (without AVX512 and AVX2),
1808   /// MKL-DNN GEMM will be used.
1809   static inline bool IsLegacyPlatform() {
1810 return (!port::TestCPUFeature(port::CPUFeature::AVX512F) &&
1811 !port::TestCPUFeature(port::CPUFeature::AVX2));
1812 }
1813
1814 /// Function to check whether primitive memory optimization is enabled
1815   static inline bool IsPrimitiveMemOptEnabled() {
1816 bool is_primitive_mem_opt_enabled = true;
1817 TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
1818 &is_primitive_mem_opt_enabled));
1819 return is_primitive_mem_opt_enabled;
1820 }
1821
1822 private:
1823   static inline LRUCache<MklPrimitive>& GetLRUCache() {
1824 static const int kCapacity = 1024; // cache capacity
1825 static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity);
1826 return lru_cache_;
1827 }
1828 };
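// Example (illustrative sketch): the caching pattern a derived factory
// typically builds on top of GetOp/SetOp. `MyPrimitive`, `factory`, `params`,
// and `CreateKey` are hypothetical names; compare with
// MklReorderPrimitiveFactory::Get below.
//
//   string key = CreateKey(params);
//   auto* prim = static_cast<MyPrimitive*>(factory.GetOp(key));
//   if (prim == nullptr) {
//     prim = new MyPrimitive(params);  // Ownership passes to the LRU cache.
//     factory.SetOp(key, prim);
//   }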
1829
1830 // Utility class for creating keys for the MKL primitive pool.
1831 class FactoryKeyCreator {
1832 public:
1833   FactoryKeyCreator() { key_.reserve(kMaxKeyLength); }
1834
1835   ~FactoryKeyCreator() {}
1836
1837   void AddAsKey(const string& str) { Append(str); }
1838
1839   void AddAsKey(const mkldnn::memory::dims& dims) {
1840 for (unsigned int i = 0; i < dims.size(); i++) {
1841 AddAsKey<int>(dims[i]);
1842 }
1843 }
1844
1845 template <typename T>
1846   void AddAsKey(const T data) {
1847 auto buffer = reinterpret_cast<const char*>(&data);
1848 Append(StringPiece(buffer, sizeof(T)));
1849 }
1850
1851   string GetKey() { return key_; }
1852
1853 private:
1854 string key_;
1855 const char delimiter = 'x';
1856 const int kMaxKeyLength = 256;
1857   void Append(StringPiece s) {
1858 key_.append(string(s));
1859 key_.append(1, delimiter);
1860 }
1861 };
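// Example (illustrative sketch): building a cache key from primitive
// parameters. The field order must be stable so that identical configurations
// always map to the same key; the values shown are made up for illustration.
//
//   FactoryKeyCreator key_creator;
//   key_creator.AddAsKey(string("conv2d_fwd"));
//   key_creator.AddAsKey(memory::dims({1, 3, 224, 224}));  // src dims
//   key_creator.AddAsKey<int>(static_cast<int>(memory::data_type::f32));
//   string key = key_creator.GetKey();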
1862
1863 class MklReorderPrimitive : public MklPrimitive {
1864 public:
1865   explicit MklReorderPrimitive(const memory* from, const memory* to)
1866 : MklPrimitive(engine(engine::kind::cpu, 0)) {
1867 Setup(from, to);
1868 }
1869   ~MklReorderPrimitive() {}
1870
1871   std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
1872
1873   void SetMemory(const memory* from, const memory* to) {
1874 context_.src_mem->set_data_handle(from->get_data_handle());
1875 context_.dst_mem->set_data_handle(to->get_data_handle());
1876 }
1877
1878   std::shared_ptr<mkldnn::stream> GetStream() { return stream_; }
1879
1880 private:
1881 struct ReorderContext {
1882 std::shared_ptr<mkldnn::memory> src_mem;
1883 std::shared_ptr<mkldnn::memory> dst_mem;
1884 std::shared_ptr<primitive> reorder_prim;
1885     ReorderContext()
1886 : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
1887 } context_;
1888
1889 std::shared_ptr<mkldnn::stream> stream_;
1890
1891   void Setup(const memory* from, const memory* to) {
1892 context_.src_mem.reset(
1893 new memory(from->get_desc(), cpu_engine_, DummyData));
1894 context_.dst_mem.reset(new memory(to->get_desc(), cpu_engine_, DummyData));
1895 context_.reorder_prim = std::make_shared<mkldnn::reorder>(
1896 reorder(*context_.src_mem, *context_.dst_mem));
1897 stream_.reset(new stream(cpu_engine_));
1898 }
1899 };
1900
1901 template <typename T>
1902 class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
1903 public:
1904   static MklReorderPrimitive* Get(const memory* from, const memory* to) {
1905 auto reorderPrim = static_cast<MklReorderPrimitive*>(
1906 MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
1907 if (reorderPrim == nullptr) {
1908 reorderPrim = new MklReorderPrimitive(from, to);
1909 MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to,
1910 reorderPrim);
1911 }
1912 reorderPrim->SetMemory(from, to);
1913 return reorderPrim;
1914 }
1915
1916   static MklReorderPrimitiveFactory& GetInstance() {
1917 static MklReorderPrimitiveFactory instance_;
1918 return instance_;
1919 }
1920
1921   static string CreateKey(const memory* from, const memory* to) {
1922 string prefix = "reorder";
1923 FactoryKeyCreator key_creator;
1924 auto const& from_desc = from->get_desc().data;
1925 auto const& to_desc = to->get_desc().data;
1926 memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
1927 memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
1928 auto from_strides = from_desc.format_desc.blocking.strides;
1929
1930     // Since the DNNL memory desc stores these fields in C-style arrays and
1931     // only initializes the used part, use only the valid part as the key.
1932 auto from_inner_nblks = from_desc.format_desc.blocking.inner_nblks;
1933 auto from_inner_blks = from_desc.format_desc.blocking.inner_blks;
1934 auto from_inner_idxs = from_desc.format_desc.blocking.inner_idxs;
1935 memory::dims from_inner_blks_1(from_inner_blks,
1936 &from_inner_blks[from_inner_nblks]);
1937 memory::dims from_inner_idxs_1(from_inner_idxs,
1938 &from_inner_idxs[from_inner_nblks]);
1939 auto to_inner_nblks = to_desc.format_desc.blocking.inner_nblks;
1940 auto to_inner_blks = to_desc.format_desc.blocking.inner_blks;
1941 auto to_inner_idxs = to_desc.format_desc.blocking.inner_idxs;
1942 memory::dims to_inner_blks_1(to_inner_blks, &to_inner_blks[to_inner_nblks]);
1943 memory::dims to_inner_idxs_1(to_inner_idxs, &to_inner_idxs[to_inner_nblks]);
1944
1945 auto to_strides = to_desc.format_desc.blocking.strides;
1946 memory::dims from_strides_outer_blocks(from_strides,
1947 &from_strides[from_desc.ndims]);
1948 memory::dims to_strides_outer_blocks(to_strides,
1949 &to_strides[to_desc.ndims]);
1950
1951 key_creator.AddAsKey(prefix);
1952 key_creator.AddAsKey(static_cast<int>(from_desc.extra.flags));
1953 key_creator.AddAsKey(static_cast<int>(from_inner_nblks));
1954 key_creator.AddAsKey(from_inner_blks_1);
1955 key_creator.AddAsKey(from_inner_idxs_1);
1956 key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
1957 key_creator.AddAsKey(from_dims);
1958 key_creator.AddAsKey(from_strides_outer_blocks);
1959 key_creator.AddAsKey(static_cast<int>(to_desc.extra.flags));
1960 key_creator.AddAsKey(static_cast<int>(to_inner_nblks));
1961 key_creator.AddAsKey(to_inner_blks_1);
1962 key_creator.AddAsKey(to_inner_idxs_1);
1963 key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
1964 key_creator.AddAsKey(to_dims);
1965 key_creator.AddAsKey(to_strides_outer_blocks);
1966 return key_creator.GetKey();
1967 }
1968
1969 private:
1970   MklReorderPrimitiveFactory() {}
1971   ~MklReorderPrimitiveFactory() {}
1972
1973   MklPrimitive* GetReorder(const memory* from, const memory* to) {
1974 string key = CreateKey(from, to);
1975 return this->GetOp(key);
1976 }
1977
1978   void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
1979 string key = CreateKey(from, to);
1980 this->SetOp(key, op);
1981 }
1982 };
1983
1984 /// Function to find (or create) a reorder from the memory pointed to by
1985 /// `from` to the memory pointed to by `to`; it creates the primitive or
1986 /// fetches it from the pool if it is already cached.
1987 /// Returns the primitive.
1988 template <typename T>
1989 inline MklReorderPrimitive* FindOrCreateReorder(const memory* from,
1990 const memory* to) {
1991 CHECK_NOTNULL(from);
1992 CHECK_NOTNULL(to);
1993 MklReorderPrimitive* reorder_prim =
1994 MklReorderPrimitiveFactory<T>::Get(from, to);
1995 return reorder_prim;
1996 }
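// Example (illustrative sketch): executing a cached reorder between two
// existing mkldnn::memory objects. `src_mem`, `dst_mem`, and `ctx` (an
// OpKernelContext*) are assumed names; the pattern mirrors the one used in
// CheckReorderToOpMem above.
//
//   MklReorderPrimitive* reorder_prim =
//       FindOrCreateReorder<float>(&src_mem, &dst_mem);
//   std::shared_ptr<stream> cpu_stream;
//   cpu_stream.reset(CreateStream(ctx, reorder_prim->GetEngine()));
//   reorder_prim->GetPrimitive()->execute(
//       *cpu_stream, {{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});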
1997
1998 // Utility function to determine whether a convolution is 1x1 with
1999 // stride != 1, for the purpose of temporarily disabling primitive reuse.
2000 inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
2001 memory::dims strides) {
2002 if (filter_dims.size() != 4 || strides.size() != 2) return false;
2003
2004 return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
2005 ((strides[0] != 1) || (strides[1] != 1)));
2006 }
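// Example (illustrative sketch): with OIHW filter dims, a 1x1 filter combined
// with a non-unit stride triggers the check, while a 3x3 filter does not.
//
//   IsConv1x1StrideNot1({32, 64, 1, 1}, {2, 2});  // true
//   IsConv1x1StrideNot1({32, 64, 3, 3}, {2, 2});  // false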
2007
2008 } // namespace tensorflow
2009
2010 /////////////////////////////////////////////////////////////////////
2011 // Macros for handling registration for various types
2012 /////////////////////////////////////////////////////////////////////
2013
2014 #define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input);
2015
2016 #ifdef ENABLE_INTEL_MKL_BFLOAT16
2017 #define REGISTER_TEST_BFLOAT16(TEST) \
2018 REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input);
2019
2020 #define REGISTER_TEST_ALL_TYPES(TEST) \
2021 REGISTER_TEST_FLOAT32(TEST); \
2022 REGISTER_TEST_BFLOAT16(TEST);
2023 #else
2024 #define REGISTER_TEST_ALL_TYPES(TEST) REGISTER_TEST_FLOAT32(TEST);
2025 #endif // ENABLE_INTEL_MKL_BFLOAT16
2026
2027 #endif // INTEL_MKL
2028 #endif // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
2029