/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include <algorithm>
#include <vector>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/tensor_format.h"

#include "mkldnn.hpp"
#include "tensorflow/core/kernels/mkl_tfconv_op.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {

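// The macros below smooth over API differences between MKL-DNN v0.x and
// v1.x: the CPU engine kind, the argument list taken by
// MklDnnData::CheckReorderToOpMem(), how the TF data format is obtained
// from an MKL shape / memory descriptor, and whether a net_args vector is
// passed to ExecutePrimitive().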
#ifdef ENABLE_MKLDNN_V1
#define ENGINE_CPU engine::kind::cpu
#define GET_CHECK_REORDER_TO_OP_MEM_ARGS(md, tensor, net, net_args, engine) \
  md, tensor, net, net_args, engine
#define GET_TF_DATA_FORMAT(shape, mem_desc) shape.GetTfDataFormat()
#define NET_ARGS_PTR &net_args
#else
#define ENGINE_CPU engine::cpu
#define GET_CHECK_REORDER_TO_OP_MEM_ARGS(md, tensor, net_ptr, net_args, \
                                         engine)                        \
  memory::primitive_desc(md, engine), tensor, &net_ptr
#define GET_TF_DATA_FORMAT(shape, mem_desc) mem_desc.data.format
#define NET_ARGS_PTR nullptr
#endif  // ENABLE_MKLDNN_V1

///////////////////////////////////////////////////////////
//               Op kernel
// Checks that the two inputs are compatible for MKL binary ops.
// Here's the basic logic:
//
// if both inputs are in TF format:
//   pass the inputs through to the output
// else if both inputs are in MKL format:
//   if both have the same shape:
//     pass the inputs through to the output
//   else:
//     convert both to TF
// else if one is TF and one is MKL:
//   if broadcast is needed:
//     convert the MKL format input to TF format
//   else:
//     convert the TF format input to MKL format
///////////////////////////////////////////////////////////

template <typename Device, typename T>
class MklInputConversionOp : public OpKernel {
 public:
  explicit MklInputConversionOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
    has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
  }

 private:
  void Compute(OpKernelContext* context) override {
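    // Fetch both input data tensors along with the MklDnnShape metadata
    // that records whether each input carries an MKL layout.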
    const int kInputIndex_0 = 0, kInputIndex_1 = 1;
    const Tensor& input_tensor_0 = MklGetInput(context, kInputIndex_0);
    MklDnnShape input_shape_0;
    GetMklShape(context, kInputIndex_0, &input_shape_0);

    const Tensor& input_tensor_1 = MklGetInput(context, kInputIndex_1);
    MklDnnShape input_shape_1;
    GetMklShape(context, kInputIndex_1, &input_shape_1);

    VLOG(1) << "MklInputConversionOp: Input shapes are: "
            << context->input(kInputIndex_0).shape().DebugString() << " and "
            << context->input(kInputIndex_1).shape().DebugString();

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // if both inputs are in TF format, just copy input tensors to output.
    if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      VLOG(1) << "MklInputConversionOp: No conversion needed, "
              << "copying TF inputs to output";

      ForwardTfTensorInToOut(context, kInputIndex_0, kInputIndex_0);
      ForwardTfTensorInToOut(context, kInputIndex_1, kInputIndex_1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // If both inputs are in MKL format
    if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      // It is safer to compare the original TensorFlow shapes than to compare
      // MKL shapes, since element-wise ops are forwarded to the Eigen
      // implementation.
      TensorShape tf_shape0 = input_shape_0.GetTfShape();
      TensorShape tf_shape1 = input_shape_1.GetTfShape();
      TensorShape tensor_shape0 = input_tensor_0.shape();
      TensorShape tensor_shape1 = input_tensor_1.shape();
      if (tf_shape0 == tf_shape1 && tensor_shape0 == tensor_shape1) {
        auto input0_md = input_shape_0.GetMklLayout();
        auto input1_md = input_shape_1.GetMklLayout();

        // If both have the same shape and same format, pass them through
        if (GET_TF_DATA_FORMAT(input_shape_0, input0_md) ==
            GET_TF_DATA_FORMAT(input_shape_1, input1_md)) {
          VLOG(1) << "MklInputConversionOp: No conversion needed, "
                  << "copying MKL inputs with identical shapes to output";

          ForwardMklTensorInToOut(context, kInputIndex_0, kInputIndex_0);
          ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
          return;
        } else {
          VLOG(1) << "MklInputConversionOp: Shape is same, but format is "
                     "different, "
                  << "need to convert to same format";
          // TODO: For now, input0 is converted and input1 is unchanged; we
          //       should choose the optimal MKL format to convert to.
          Tensor* tensor_out;
          MklDnnShape mkl_output_mkl_shape;
          mkl_output_mkl_shape.SetMklTensor(true);
          mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
          mkl_output_mkl_shape.SetTfLayout(input_shape_0.GetDimension(),
                                           input_shape_0.GetSizesAsMklDnnDims(),
                                           input_shape_0.GetTfDataFormat());

          // Get MKL layout from input1 as destination layout
          mkl_output_mkl_shape.SetMklLayout(&input1_md);

          // Create output Mkl tensor for index 0
          AllocateOutputSetMklShape(context, kInputIndex_0, &tensor_out,
                                    input_tensor_0.shape(),
                                    mkl_output_mkl_shape);

          // Create MklDnnData object for input0 tensor
          auto cpu_engine = engine(ENGINE_CPU, 0);
          MklDnnData<T> input(&cpu_engine);
          input.SetUsrMem(input0_md, &input_tensor_0);
          // Create reorder from input0's layout to input1's layout
          std::vector<primitive> net;
          std::vector<MemoryArgsMap> net_args;
          // TODO(bhavanis): Refactor CheckReorderToOpMem() to create and
          // execute reorder
          OP_REQUIRES(
              context,
              input.CheckReorderToOpMem(GET_CHECK_REORDER_TO_OP_MEM_ARGS(
                  input1_md, tensor_out, net, net_args, cpu_engine)),
              errors::Internal(
                  "MklInputConversionOp: Failed to create reorder for input0"));
          ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
          // Input1 will be passed through
          ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
          return;
        }
      }

      // Sanity check
      bool mkl_shapes_are_same = ((input_shape_0 == input_shape_1) &&
                                  (tensor_shape0 == tensor_shape1));
      if (mkl_shapes_are_same) {
        CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
                        "different but MKL shapes are same";
      }

      // Both have different shapes, so broadcast will be necessary.
      // Convert to TF and pass both tensors through (we can't do broadcast
      // with MKL tensors)
      VLOG(1) << "MklInputConversionOp: Broadcast needed, "
              << "converted MKL inputs to TF format";
      // TODO: Cleanup op_data_type and has_avx512f_ after these two parameters
      //       are removed from ConvertMklToTf
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           kInputIndex_0);
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           kInputIndex_1);
      SetDummyMklDnnShapeOutput(context, kInputIndex_0);
      SetDummyMklDnnShapeOutput(context, kInputIndex_1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // One input is MKL and one is TF. If no broadcast is needed, convert
    // the TF tensor to MKL, otherwise convert the MKL tensor to TF format
    VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)";

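    // Identify which input carries the MKL layout and which is a plain TF
    // tensor.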
    const Tensor* mkl_tensor;
    const MklDnnShape* mkl_shape;
    const Tensor* tf_tensor;
    uint mkl_tensor_index;
    uint tf_tensor_index;
    if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_0;
      mkl_shape = &input_shape_0;
      mkl_tensor_index = 0;
      tf_tensor = &input_tensor_1;
      tf_tensor_index = 1;
    } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_1;
      mkl_shape = &input_shape_1;
      mkl_tensor_index = 1;
      tf_tensor = &input_tensor_0;
      tf_tensor_index = 0;
    } else {
      CHECK(false) << "MklInputConversionOp: Unexpected combination of input "
                      "shapes for MKL "
                   << "element-wise op";
    }

    // Broadcast is needed if the shapes are not the same
    if (mkl_shape->GetTfShape().num_elements() ==
        tf_tensor->shape().num_elements()) {
      // Both shapes are the same; convert the TF input to MKL
      VLOG(1) << "MklInputConversionOp: No broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
              << " to MKL format";

      // Create MklDnnShape for output Mkl tensor.
      Tensor* tensor_out;
      MklDnnShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(true);
      mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
      mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(),
                                       mkl_shape->GetSizesAsMklDnnDims(),
                                       mkl_shape->GetTfDataFormat());
      // ** Temporarily borrow the layout from the MKL input **
      auto output_mkl_md = mkl_shape->GetMklLayout();
      mkl_output_mkl_shape.SetMklLayout(&output_mkl_md);

      // Create output Mkl tensor
      AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out,
                                mkl_tensor->shape(), mkl_output_mkl_shape);

      // Create MklDnnData object for input tensor. Input tensor is in
      // TensorFlow layout.
      auto cpu_engine = engine(ENGINE_CPU, 0);
      MklDnnData<T> tf_input(&cpu_engine);
      auto input_tf_md = mkl_output_mkl_shape.GetTfLayout();
      tf_input.SetUsrMem(input_tf_md, tf_tensor);
      // Create reorder between TF layout and MKL layout if necessary
      std::vector<primitive> net;
      std::vector<MemoryArgsMap> net_args;
      bool reordered =
          tf_input.CheckReorderToOpMem(GET_CHECK_REORDER_TO_OP_MEM_ARGS(
              output_mkl_md, tensor_out, net, net_args, cpu_engine));
      if (!reordered) {
        // This is the case where the TF tensor has the same shape and format
        // as the MKL tensor. However, tf_tensor cannot simply be forwarded to
        // the output tensor, since the MKL data tensor is always a
        // one-dimensional tensor. Tensor::CopyFrom shares the other tensor's
        // buffer while setting the output's shape.
        OP_REQUIRES(context,
                    tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()),
                    errors::Internal("MklInputConversionOp: Failed to forward "
                                     "input tensor to output"));
      } else {
        ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
      }

      // -- The tensor in MKL format passes through --
      ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
    } else {
      // Broadcast is needed, so convert the MKL input to TF
      VLOG(1) << "MklInputConversionOp: Broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index
              << " to TF format";
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           mkl_tensor_index);
      SetDummyMklDnnShapeOutput(context, mkl_tensor_index);

      // The tensor in TF format passes through
      ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
    }

    VLOG(1) << "MklInputConversionOp: Shapes (output): "
            << context->mutable_output(kInputIndex_0)->shape().DebugString()
            << " and "
            << context->mutable_output(kInputIndex_1)->shape().DebugString();

    VLOG(1) << "MklInputConversion completed successfully.";
  }

 private:
  /// Data format of the operation
  string data_format_str;

  /// Data type of the operation
  DataType op_data_type;

  /// True if the CPU supports the AVX512F instruction set
  bool has_avx512f_ = false;
};

///////////////////////////////////////////////////////////
//               Register kernel
///////////////////////////////////////////////////////////

#define REGISTER_CPU(T)                                        \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklInputConversion")                              \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklInputConversionOp<CPUDevice, T>);

// TODO(nhasabni): We cannot support all number types since MklDnn does
// not support them.
// TF_CALL_NUMBER_TYPES(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);

#undef REGISTER_CPU
#undef ENGINE_CPU
#undef GET_CHECK_REORDER_TO_OP_MEM_ARGS
#undef GET_TF_DATA_FORMAT
#undef NET_ARGS_PTR

}  // namespace tensorflow
#endif  // INTEL_MKL