• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifdef INTEL_MKL
17 
18 #include <algorithm>
19 #include <vector>
20 
21 #include "mkldnn.hpp"
22 #include "tensorflow/core/framework/numeric_op.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/kernels/mkl/mkl_tfconv_op.h"
29 #include "tensorflow/core/kernels/ops_util.h"
30 #include "tensorflow/core/platform/byte_order.h"
31 #include "tensorflow/core/platform/cpu_info.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow/core/util/mkl_util.h"
34 #include "tensorflow/core/util/tensor_format.h"
35 
36 namespace tensorflow {
37 
///////////////////////////////////////////////////////////
//               Op kernel
// Checks and ensures that the 2 inputs are compatible for mkl binary ops.
// Here's the basic logic:
//
// if both inputs are in TF format:
//   pass the inputs through to the output
// else if both inputs are in mkl format:
//   if both have the same shape:
//     pass the inputs through to the output
//   else:
//     convert both to TF
// else if one is TF and one is MKL:
//   if broadcast is needed:
//     convert the MKL format input to TF format
//   else:
//     convert the TF format input to MKL format
///////////////////////////////////////////////////////////
56 
57 template <typename Device, typename T>
58 class MklInputConversionOp : public OpKernel {
59  public:
MklInputConversionOp(OpKernelConstruction * context)60   explicit MklInputConversionOp(OpKernelConstruction* context)
61       : OpKernel(context) {
62     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
63     OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
64     has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
65   }
66 
67  private:
Compute(OpKernelContext * context)68   void Compute(OpKernelContext* context) override {
69     const int kInputIndex_0 = 0, kInputIndex_1 = 1;
70     const Tensor& input_tensor_0 = MklGetInput(context, kInputIndex_0);
71     MklDnnShape input_shape_0;
72     GetMklShape(context, kInputIndex_0, &input_shape_0);
73 
74     const Tensor& input_tensor_1 = MklGetInput(context, kInputIndex_1);
75     MklDnnShape input_shape_1;
76     GetMklShape(context, kInputIndex_1, &input_shape_1);
77 
78     VLOG(1) << "MklInputConversionOp: Input shapes are: "
79             << context->input(kInputIndex_0).shape().DebugString() << " and "
80             << context->input(kInputIndex_1).shape().DebugString();
81 
82     // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
83     // if both inputs are in TF format, just copy input tensors to output.
84     if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
85       VLOG(1) << "MklInputConversionOp: No conversion needed, "
86               << "copying TF inputs to output";
87 
88       ForwardTfTensorInToOut(context, kInputIndex_0, kInputIndex_0);
89       ForwardTfTensorInToOut(context, kInputIndex_1, kInputIndex_1);
90       return;
91     }
92 
93     // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
94     // If both inputs are in MKL format
95     if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
96       // It is safer to compare the original TensorFlow shapes than to compare
97       // Mkl shapes since element wise ops are forwarded to Eigen
98       // implementation.
99       TensorShape tf_shape0 = input_shape_0.GetTfShape();
100       TensorShape tf_shape1 = input_shape_1.GetTfShape();
101       TensorShape tensor_shape0 = input_tensor_0.shape();
102       TensorShape tensor_shape1 = input_tensor_1.shape();
103       if (tf_shape0 == tf_shape1 && tensor_shape0 == tensor_shape1) {
104         auto input0_md = input_shape_0.GetMklLayout();
105         auto input1_md = input_shape_1.GetMklLayout();
106 
107         // If both have the same shape and same format, pass them through
108         if (input_shape_0.GetTfDataFormat() ==
109             input_shape_1.GetTfDataFormat()) {
110           VLOG(1) << "MklInputConversionOp: No conversion needed, "
111                   << "copying MKL inputs with identical shapes to output";
112 
113           ForwardMklTensorInToOut(context, kInputIndex_0, kInputIndex_0);
114           ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
115           return;
116         } else {
117           VLOG(1) << "MklInputConversionOp: Shape is same, but format is "
118                      "different, "
119                   << "need to convert to same format";
120           // TODO: For now, input0 is converted and input1 is unchanged
121           //       we should choose the optimal MKL format to convert to.
122           Tensor* tensor_out;
123           MklDnnShape mkl_output_mkl_shape;
124           mkl_output_mkl_shape.SetMklTensor(true);
125           mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
126           mkl_output_mkl_shape.SetTfLayout(input_shape_0.GetDimension(),
127                                            input_shape_0.GetSizesAsMklDnnDims(),
128                                            input_shape_0.GetTfDataFormat());
129 
130           // Get MKL layout from input1 as destination layout
131           mkl_output_mkl_shape.SetMklLayout(&input1_md);
132 
133           // Create output Mkl tensor for index 0
134           AllocateOutputSetMklShape(context, kInputIndex_0, &tensor_out,
135                                     input_tensor_0.shape(),
136                                     mkl_output_mkl_shape);
137 
138           // Create MklDnnData object for input0 tensor
139           auto cpu_engine = engine(engine::kind::cpu, 0);
140           MklDnnData<T> input(&cpu_engine);
141           input.SetUsrMem(input0_md, &input_tensor_0);
142           // Create reorder from input0's layout to input1's layout
143           std::vector<primitive> net;
144           std::vector<MemoryArgsMap> net_args;
145           // TODO(bhavanis): Refactor CheckReorderToOpMem() to create and
146           // execute reorder
147           OP_REQUIRES(
148               context,
149               input.CheckReorderToOpMem(input1_md, tensor_out, net, net_args,
150                                         cpu_engine),
151               errors::Internal(
152                   "MklInputConversionOp: Failed to create reorder for input0"));
153           ExecutePrimitive(net, &net_args, cpu_engine, context);
154           // Input1 will be passed through
155           ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
156           return;
157         }
158       }
159 
160       // Sanity check
161       bool mkl_shapes_are_same = ((input_shape_0 == input_shape_1) &&
162                                   (tensor_shape0 == tensor_shape1));
163       if (mkl_shapes_are_same) {
164         CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
165                         "different but MKL shapes are same";
166       }
167 
168       // Both have different shapes, so broadcast will be necessary.
169       // Convert to TF and pass both tensors through (we can't do broadcast
170       // with MKL tensors)
171       VLOG(1) << "MklInputConversionOp: Broadcast needed, "
172               << "converted MKL inputs to TF format";
173       // TODO: Cleanup op_data_type and has_avx512f_ after these two parameters
174       //       are removed from ConvertMklToTf
175       MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
176                                            op_data_type, has_avx512f_,
177                                            kInputIndex_0);
178       MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
179                                            op_data_type, has_avx512f_,
180                                            kInputIndex_1);
181       SetDummyMklDnnShapeOutput(context, kInputIndex_0);
182       SetDummyMklDnnShapeOutput(context, kInputIndex_1);
183       return;
184     }
185 
186     // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
187     // One input is MKL and one is TF. If no broadcast is needed, convert
188     // the TF tensor to MKL, otherwise convert the MKL tensor to TF format
189     VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)";
190 
191     const Tensor* mkl_tensor;
192     const MklDnnShape* mkl_shape;
193     const Tensor* tf_tensor;
194     uint mkl_tensor_index;
195     uint tf_tensor_index;
196     if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
197       mkl_tensor = &input_tensor_0;
198       mkl_shape = &input_shape_0;
199       mkl_tensor_index = 0;
200       tf_tensor = &input_tensor_1;
201       tf_tensor_index = 1;
202     } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
203       mkl_tensor = &input_tensor_1;
204       mkl_shape = &input_shape_1;
205       mkl_tensor_index = 1;
206       tf_tensor = &input_tensor_0;
207       tf_tensor_index = 0;
208     } else {
209       CHECK(false) << "MklInputConversionOp: Unexpected combination of input "
210                       "shapes for MKL "
211                    << "element-wise op";
212     }
213 
214     // Broadcast is needed if the shapes are not the same
215     if (mkl_shape->GetTfShape().num_elements() ==
216         tf_tensor->shape().num_elements()) {
217       // Both shapes are same, convert the TF input to MKL
218       VLOG(1) << "MklInputConversionOp: No broadcast needed.";
219       VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
220               << " to MKL format";
221 
222       // Create MklDnnShape for output Mkl tensor.
223       Tensor* tensor_out;
224       MklDnnShape mkl_output_mkl_shape;
225       mkl_output_mkl_shape.SetMklTensor(true);
226       mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
227       mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(),
228                                        mkl_shape->GetSizesAsMklDnnDims(),
229                                        mkl_shape->GetTfDataFormat());
230       // ** Temporarily borrow the layout from the MKL input **
231       auto output_mkl_md = mkl_shape->GetMklLayout();
232       mkl_output_mkl_shape.SetMklLayout(&output_mkl_md);
233 
234       // Create output Mkl tensor
235       AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out,
236                                 mkl_tensor->shape(), mkl_output_mkl_shape);
237 
238       // Create MklDnnData object for input tensor. Input tensor is in
239       // Tensorflow layout.
240       auto cpu_engine = engine(engine::kind::cpu, 0);
241       MklDnnData<T> tf_input(&cpu_engine);
242       auto input_tf_md = mkl_output_mkl_shape.GetTfLayout();
243       tf_input.SetUsrMem(input_tf_md, tf_tensor);
244       // Create reorder between TF layout and MKL layout if necessary
245       std::vector<primitive> net;
246       std::vector<MemoryArgsMap> net_args;
247       bool reordered = tf_input.CheckReorderToOpMem(output_mkl_md, tensor_out,
248                                                     net, net_args, cpu_engine);
249       if (!reordered) {
250         // This is the case that the TF tensor has the same shape and format of
251         // mkl tensor. However, tf_tensor can not be simply forwarded to the
252         // output tensor since mkl data tensor is always one dimensional tensor.
253         // Tensor::CopyFrom shares the buffer of the other tensor while set its
254         // shape to the other tensor.
255         OP_REQUIRES(context,
256                     tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()),
257                     errors::Internal("MklInputConversionOp: Failed to forward "
258                                      "input tensor to output"));
259       } else {
260         ExecutePrimitive(net, &net_args, cpu_engine, context);
261       }
262 
263       // -- The tensor in MKL format passes through --
264       ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
265     } else {
266       // Broadcast is needed, so convert the MKL input to TF
267       VLOG(1) << "MklInputConversionOp: Broadcast needed.";
268       VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index
269               << " to TF format";
270       MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
271                                            op_data_type, has_avx512f_,
272                                            mkl_tensor_index);
273       SetDummyMklDnnShapeOutput(context, mkl_tensor_index);
274 
275       // The tensor in TF format passes through
276       ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
277     }
278 
279     VLOG(1) << "MklInputConversionOp: Shapes (output): "
280             << context->mutable_output(kInputIndex_0)->shape().DebugString()
281             << " and "
282             << context->mutable_output(kInputIndex_1)->shape().DebugString();
283 
284     VLOG(1) << "MklInputConversion completed successfully.";
285   }
286 
287  private:
288   /// Data format of the operation
289   string data_format_str;
290 
291   /// Data type of the operation
292   DataType op_data_type;
293 
294   /// CPUIDInfo
295   bool has_avx512f_ = false;
296 };
297 
298 ///////////////////////////////////////////////////////////
299 //               Register kernel
300 ///////////////////////////////////////////////////////////
301 
302 #define REGISTER_CPU(T)                                        \
303   REGISTER_KERNEL_BUILDER(                                     \
304       Name("_MklInputConversion")                              \
305           .Device(DEVICE_CPU)                                  \
306           .TypeConstraint<T>("T")                              \
307           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
308       MklInputConversionOp<CPUDevice, T>);
309 
310 // TODO(nhasabni): We cannot support all number types since MklDnn does
311 // not support types.
312 // TF_CALL_NUMBER_TYPES(REGISTER_CPU);
313 TF_CALL_float(REGISTER_CPU);
314 TF_CALL_bfloat16(REGISTER_CPU);
315 
316 #undef REGISTER_CPU
317 
318 }  // namespace tensorflow
319 #endif  // INTEL_MKL
320