• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifdef INTEL_MKL
17 
18 #include <algorithm>
19 #include <vector>
20 #include "tensorflow/core/framework/numeric_op.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/register_types.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_shape.h"
26 #include "tensorflow/core/kernels/ops_util.h"
27 #include "tensorflow/core/platform/byte_order.h"
28 #include "tensorflow/core/platform/cpu_info.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/core/util/tensor_format.h"
31 
32 #include "mkldnn.hpp"
33 #include "tensorflow/core/kernels/mkl_tfconv_op.h"
34 #include "tensorflow/core/util/mkl_util.h"
35 using mkldnn::stream;
36 
37 namespace tensorflow {
38 typedef Eigen::ThreadPoolDevice CPUDevice;
39 
///////////////////////////////////////////////////////////
//               Op kernel
// Checks and ensures that the 2 inputs are compatible for mkl binary ops.
// Here's the basic logic:
//
// if both inputs are in TF format:
//   pass the inputs through to the output
// else if both inputs are in mkl format:
//   if both have the same shape:
//     if both have the same mkl layout:
//       pass the inputs through to the output
//     else:
//       reorder input0 to input1's layout and pass input1 through
//   else:
//     convert both to TF (broadcast cannot be done on mkl tensors)
// else if one is TF and one is MKL:
//   if broadcast is needed:
//     convert the MKL format input to TF format
//   else:
//     convert the TF format input to MKL format
///////////////////////////////////////////////////////////
58 
// Normalizes the layouts of the two inputs of an MKL element-wise binary op
// so the downstream op receives a compatible pair of tensors:
//  - both TF:   forward both inputs unchanged.
//  - both MKL:  forward both if shapes and MKL layouts already match;
//               reorder input0 into input1's layout if only layouts differ;
//               convert both to TF format when shapes differ (broadcast).
//  - mixed:     convert the TF input to MKL when element counts match,
//               otherwise convert the MKL input to TF format.
template <typename Device, typename T>
class MklInputConversionOp : public OpKernel {
 public:
  explicit MklInputConversionOp(OpKernelConstruction* context)
      : OpKernel(context) {
    // data_format, T, and AVX512F availability are forwarded to
    // MklToTfOp::ConvertMklToTf whenever an MKL->TF conversion is needed.
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
    has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
  }

 private:
  // Note: private access is fine here -- OpKernel::Compute is dispatched
  // virtually through the (public) base-class declaration.
  void Compute(OpKernelContext* context) override {
    const int kInputIndex_0 = 0, kInputIndex_1 = 1;
    // Fetch both data tensors along with their MKL shape metadata.
    const Tensor& input_tensor_0 = MklGetInput(context, kInputIndex_0);
    MklDnnShape input_shape_0;
    GetMklShape(context, kInputIndex_0, &input_shape_0);

    const Tensor& input_tensor_1 = MklGetInput(context, kInputIndex_1);
    MklDnnShape input_shape_1;
    GetMklShape(context, kInputIndex_1, &input_shape_1);

    VLOG(1) << "MklInputConversionOp: Input shapes are: "
            << context->input(kInputIndex_0).shape().DebugString() << " and "
            << context->input(kInputIndex_1).shape().DebugString();

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // If both inputs are in TF format, just copy input tensors to output.
    if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      VLOG(1) << "MklInputConversionOp: No conversion needed, "
              << "copying TF inputs to output";

      ForwardTfTensorInToOut(context, kInputIndex_0, kInputIndex_0);
      ForwardTfTensorInToOut(context, kInputIndex_1, kInputIndex_1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // If both inputs are in MKL format
    if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      // It is safer to compare the original TensorFlow shapes than to compare
      // Mkl shapes since element wise ops are forwarded to Eigen
      // implementation.
      TensorShape tf_shape0 = input_shape_0.GetTfShape();
      TensorShape tf_shape1 = input_shape_1.GetTfShape();
      TensorShape tensor_shape0 = input_tensor_0.shape();
      TensorShape tensor_shape1 = input_tensor_1.shape();
      if (tf_shape0 == tf_shape1 && tensor_shape0 == tensor_shape1) {
        auto input0_md = input_shape_0.GetMklLayout();
        auto input1_md = input_shape_1.GetMklLayout();

        // If both have the same shape and same format, pass them through
        if (input0_md.data.format == input1_md.data.format) {
          VLOG(1) << "MklInputConversionOp: No conversion needed, "
                  << "copying MKL inputs with identical shapes to output";

          ForwardMklTensorInToOut(context, kInputIndex_0, kInputIndex_0);
          ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
          return;
        } else {
          VLOG(1) << "MklInputConversionOp: Shape is same, but format is "
                     "different, "
                  << "need to convert to same format";
          // TODO: For now, input0 is converted and input1 is unchanged
          //       we should choose the optimal MKL format to convert to.
          Tensor* tensor_out;
          MklDnnShape mkl_output_mkl_shape;
          mkl_output_mkl_shape.SetMklTensor(true);
          mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
          mkl_output_mkl_shape.SetTfLayout(input_shape_0.GetDimension(),
                                           input_shape_0.GetSizesAsMklDnnDims(),
                                           input_shape_0.GetTfDataFormat());

          // Get MKL layout from input1 as destination layout
          mkl_output_mkl_shape.SetMklLayout(&input1_md);

          // Create output Mkl tensor for index 0
          AllocateOutputSetMklShape(context, kInputIndex_0, &tensor_out,
                                    input_tensor_0.shape(),
                                    mkl_output_mkl_shape);

          // Create MklDnnData object for input0 tensor
          auto cpu_engine = engine(engine::cpu, 0);
          MklDnnData<T> input(&cpu_engine);
          input.SetUsrMem(input0_md, &input_tensor_0);

          // Create reorder from input0's layout to input1's layout.
          // CheckReorderToOpMem must return true here: the layouts are known
          // to differ, so a reorder primitive is always appended to the net.
          std::vector<primitive> net;
          CHECK_EQ(input.CheckReorderToOpMem(
                       memory::primitive_desc(input1_md, cpu_engine),
                       tensor_out, &net),
                   true);
          stream(stream::kind::eager).submit(net).wait();

          // Input1 will be passed through
          ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
          return;
        }
      }

      // Sanity check: TF shapes differ at this point, so identical MKL
      // shapes would indicate inconsistent metadata.
      bool mkl_shapes_are_same = ((input_shape_0 == input_shape_1) &&
                                  (tensor_shape0 == tensor_shape1));
      if (mkl_shapes_are_same) {
        CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
                        "different but MKL shapes are same";
      }

      // Both have different shapes, so broadcast will be necessary.
      // Convert to TF and pass both tensors through (we can't do broadcast
      // with MKL tensors)
      VLOG(1) << "MklInputConversionOp: Broadcast needed, "
              << "converted MKL inputs to TF format";
      // TODO: Cleanup op_data_type and has_avx512f_ after these two parameters
      //       are removed from ConvertMklToTf
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           kInputIndex_0);
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           kInputIndex_1);
      // Mark both meta outputs as "not an MKL tensor".
      SetDummyMklDnnShapeOutput(context, kInputIndex_0);
      SetDummyMklDnnShapeOutput(context, kInputIndex_1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // One input is MKL and one is TF. If no broadcast is needed, convert
    // the TF tensor to MKL, otherwise convert the MKL tensor to TF format
    VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)";

    // Identify which input is which. The indices are always assigned before
    // use: exactly one of the two branches below must match (the preceding
    // code returned for TF/TF and MKL/MKL), and the else branch CHECK-fails.
    const Tensor* mkl_tensor;
    const MklDnnShape* mkl_shape;
    const Tensor* tf_tensor;
    uint mkl_tensor_index;
    uint tf_tensor_index;
    if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_0;
      mkl_shape = &input_shape_0;
      mkl_tensor_index = 0;
      tf_tensor = &input_tensor_1;
      tf_tensor_index = 1;
    } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_1;
      mkl_shape = &input_shape_1;
      mkl_tensor_index = 1;
      tf_tensor = &input_tensor_0;
      tf_tensor_index = 0;
    } else {
      CHECK(false) << "MklInputConversionOp: Unexpected combination of input "
                      "shapes for MKL "
                   << "element-wise op";
    }

    // Broadcast is needed if the shapes are not the same
    if (mkl_shape->GetTfShape().num_elements() ==
        tf_tensor->shape().num_elements()) {
      // Both shapes are same, convert the TF input to MKL
      VLOG(1) << "MklInputConversionOp: No broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
              << " to MKL format";

      // Create MklDnnShape for output Mkl tensor.
      Tensor* tensor_out;
      MklDnnShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(true);
      mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
      mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(),
                                       mkl_shape->GetSizesAsMklDnnDims(),
                                       mkl_shape->GetTfDataFormat());
      // ** Temporarily borrow the layout from the MKL input **
      auto output_mkl_md = mkl_shape->GetMklLayout();
      mkl_output_mkl_shape.SetMklLayout(&output_mkl_md);

      // Create output Mkl tensor
      AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out,
                                mkl_tensor->shape(), mkl_output_mkl_shape);

      // Create MklDnnData object for input tensor. Input tensor is in
      // Tensorflow layout.
      auto cpu_engine = engine(engine::cpu, 0);
      MklDnnData<T> tf_input(&cpu_engine);
      auto input_tf_md = mkl_output_mkl_shape.GetTfLayout();
      tf_input.SetUsrMem(input_tf_md, tf_tensor);

      // Create reorder between tensorflow layout and Mkl layout if necessary
      std::vector<primitive> net;
      bool reordered = tf_input.CheckReorderToOpMem(
          memory::primitive_desc(output_mkl_md, cpu_engine), tensor_out, &net);

      if (!reordered) {
        // This is the case that the TF tensor has the same shape and format of
        // mkl tensor. However, tf_tensor can not be simply forwarded to the
        // output tensor since mkl data tensor is always one dimensional tensor.
        // Tensor::CopyFrom shares the buffer of the other tensor while set its
        // shape to the other tensor.
        CHECK(tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()));
      } else {
        stream(stream::kind::eager).submit(net).wait();
      }

      // -- The tensor in MKL format passes through --
      ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
    } else {
      // Broadcast is needed, so convert the MKL input to TF
      VLOG(1) << "MklInputConversionOp: Broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index
              << " to TF format";
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           mkl_tensor_index);
      SetDummyMklDnnShapeOutput(context, mkl_tensor_index);

      // The tensor in TF format passes through
      ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
    }

    VLOG(1) << "MklInputConversionOp: Shapes (output): "
            << context->mutable_output(kInputIndex_0)->shape().DebugString()
            << " and "
            << context->mutable_output(kInputIndex_1)->shape().DebugString();

    VLOG(1) << "MklInputConversion completed successfully.";
  }

 private:
  /// Data format of the operation
  string data_format_str;

  /// Data type of the operation
  DataType op_data_type;

  /// CPUIDInfo: true when the host CPU supports AVX512F
  bool has_avx512f_ = false;
};
293 
294 ///////////////////////////////////////////////////////////
295 //               Register kernel
296 ///////////////////////////////////////////////////////////
297 
// Registers the _MklInputConversion kernel on CPU for type T, under the MKL
// op label (mkl_op_registry::kMklOpLabel).
#define REGISTER_CPU(T)                                             \
  REGISTER_KERNEL_BUILDER(Name("_MklInputConversion")               \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklInputConversionOp<CPUDevice, T>);

// TODO(nhasabni): We cannot support all number types since MklDnn does
// not support types.
// TF_CALL_NUMBER_TYPES(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);  // Only float is instantiated for now.
#undef REGISTER_CPU
310 }  // namespace tensorflow
311 #endif  // INTEL_MKL
312