• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
17 #define TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
18 #ifdef INTEL_MKL
19 
20 #include "tensorflow/core/framework/op_kernel.h"
21 
22 namespace tensorflow {
// Since our ops are going to produce and also consume N additional tensors
// (Mkl) for N Tensorflow tensors, we can have following different
// orderings among these 2N tensors.
//
// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
// consume A_m, B_m, and C_m additionally.
//
// INTERLEAVED: in this case 2N tensors are interleaved. So for above
//              example, the ordering looks like: A, A_m, B, B_m, C, C_m.
//
// CONTIGUOUS: in this case N Tensorflow tensors are contiguous followed
//             by N Mkl tensors. So for above example, the ordering looks
//             like: A, B, C, A_m, B_m, C_m
//
// Following APIs map index of original Tensorflow tensors to their
// appropriate position based on selected ordering. For contiguous ordering,
// we need to know the total number of tensors (parameter total).
//
typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
// NOTE: Currently, we use contiguous ordering. If you change this, then you
// would need to change Mkl op definitions in nn_ops.cc.
static const MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
45 
46 // Get index of MetaData tensor from index 'n' of Data tensor.
DataIndexToMetaDataIndex(int n,int total_tensors)47 inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
48   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
49     // For interleaved ordering, Mkl tensor follows immediately after
50     // Tensorflow tensor.
51     return n + 1;
52   } else {
53     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
54     // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
55     return n + total_tensors / 2;
56   }
57 }
58 
GetTensorDataIndex(int n,int total_tensors)59 int inline GetTensorDataIndex(int n, int total_tensors) {
60   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
61     return 2 * n;  // index corresponding to nth input/output tensor
62   } else {
63     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
64     return n;
65   }
66 }
67 
GetTensorMetaDataIndex(int n,int total_tensors)68 int inline GetTensorMetaDataIndex(int n, int total_tensors) {
69   // Get index for TensorData first and then use mapping function
70   // to get TensorMetaData index from TensorData index.
71   int tidx = GetTensorDataIndex(n, total_tensors);
72   return DataIndexToMetaDataIndex(tidx, total_tensors);
73 }
74 
namespace mkl_op_registry {
// Kernel registration label for Mkl-compliant (float) kernels.
static const char* kMklOpLabel = "MklOp";
// Substring searched for in KernelsRegisteredForOp() output to detect an
// Mkl-labeled kernel registration.
static const char* kMklOpLabelPattern = "label='MklOp'";
// Kernel registration label for quantized Mkl-compliant kernels.
static const char* kMklQuantizedOpLabel = "QuantizedMklOp";
// Substring searched for to detect a quantized-Mkl kernel registration.
static const char* kMklQuantizedOpLabelPattern = "label='QuantizedMklOp'";
// Prefix that we add to Tensorflow op name to construct Mkl op name.
static const char* const kMklOpPrefix = "_Mkl";
82 
83 // Get the name of Mkl op from original TensorFlow op
84 // We prefix 'Mkl' to the original op to get Mkl op.
GetMklOpName(const string & name)85 inline string GetMklOpName(const string& name) {
86   return string(kMklOpPrefix) + name;
87 }
88 
89 // Check whether opname with type T is registered as MKL-compliant.
90 //
91 // @input: name of the op
92 // @input: T datatype to be used for checking op
93 // @return: true if opname is registered as Mkl op; false otherwise
IsMklOp(const string & op_name,DataType T)94 static inline bool IsMklOp(const string& op_name, DataType T) {
95   string kernel = KernelsRegisteredForOp(op_name);
96 
97   // Restrict quantized ops to QUINT8 and QINT8 for now
98   if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) {
99     return (T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32);
100   }
101   // Restrict regular ops to FLOAT
102   if (kernel.find(kMklOpLabelPattern) != string::npos) {
103     return (T == DT_FLOAT);
104   }
105   return false;
106 }
107 
108 // TODO(mdfaijul): QuantizedConv2D is registered with input: QUINT8
109 // filter:QINT8 for mkldnn integration. First a dummy kernel is created
110 // and then it is replaced by an actual kernel.
IsMklOp(const string & op_name,DataType Tinput,DataType Tfilter)111 static inline bool IsMklOp(const string& op_name, DataType Tinput,
112                            DataType Tfilter) {
113   string kernel = KernelsRegisteredForOp(op_name);
114 
115   // Restrict quantized ops to QUINT8 and QINT8 for now
116   if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) {
117     return (Tinput == DT_QUINT8 && Tfilter == DT_QINT8);
118   }
119   return false;
120 }
121 
122 // Check whether opname with type T is registered as MKL-compliant and
123 // is element-wise.
124 //
125 // @input: name of the op
126 // @input: T datatype to be used for checking op
127 // @return: true if opname is registered as element-wise Mkl op;
128 // false otherwise
IsMklElementWiseOp(const string & op_name,DataType T)129 static inline bool IsMklElementWiseOp(const string& op_name, DataType T) {
130   if (!IsMklOp(op_name, T)) {
131     return false;
132   }
133   bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
134                  0 == op_name.compare(GetMklOpName("Sub")) ||
135                  0 == op_name.compare(GetMklOpName("Mul")) ||
136                  0 == op_name.compare(GetMklOpName("Maximum")) ||
137                  0 == op_name.compare(GetMklOpName("SquaredDifference")));
138 
139   return result;
140 }
141 }  // namespace mkl_op_registry
142 }  // namespace tensorflow
143 #endif  // INTEL_MKL
144 #endif  // TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
145