/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include <algorithm>
#include <vector>

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/tensor_format.h"

#include "mkldnn.hpp"
#include "tensorflow/core/kernels/mkl_tfconv_op.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {

// These macros paper over the API differences between MKL-DNN v0.x and
// v1.x so that the kernel body below can be written once for both
// versions. They are #undef'd at the end of this file.
#ifdef ENABLE_MKLDNN_V1
#define ENGINE_CPU engine::kind::cpu
#define GET_CHECK_REORDER_TO_OP_MEM_ARGS(md, tensor, net, net_args, engine) \
  md, tensor, net, net_args, engine
#define GET_TF_DATA_FORMAT(shape, mem_desc) shape.GetTfDataFormat()
#define NET_ARGS_PTR &net_args
#else
#define ENGINE_CPU engine::cpu
#define GET_CHECK_REORDER_TO_OP_MEM_ARGS(md, tensor, net_ptr, net_args, \
                                         engine)                        \
  memory::primitive_desc(md, engine), tensor, &net_ptr
#define GET_TF_DATA_FORMAT(shape, mem_desc) mem_desc.data.format
#define NET_ARGS_PTR nullptr
#endif  // ENABLE_MKLDNN_V1

///////////////////////////////////////////////////////////
// Op kernel
// Checks and ensures that the 2 inputs are compatible for mkl binary ops.
// Here's the basic logic:
//
// if both inputs are in TF format:
//   pass the inputs through to the output
// else if both inputs are in mkl format:
//   if both have the same shape:
//     pass the inputs through to the output
//   else:
//     convert both to TF
// else if one is TF and one is MKL:
//   if broadcast is needed:
//     convert the MKL format input to TF format
//   else:
//     convert the TF format input to MKL format
///////////////////////////////////////////////////////////

template <typename Device, typename T>
class MklInputConversionOp : public OpKernel {
 public:
  explicit MklInputConversionOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
    has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
  }

 private:
  // Implements the decision table in the file-level comment above: each of
  // the three top-level branches either forwards the inputs untouched,
  // reorders one input to the other's MKL layout, or falls back to TF
  // format when broadcasting is required.
  void Compute(OpKernelContext* context) override {
    const int kInputIndex_0 = 0, kInputIndex_1 = 1;
    const Tensor& input_tensor_0 = MklGetInput(context, kInputIndex_0);
    MklDnnShape input_shape_0;
    GetMklShape(context, kInputIndex_0, &input_shape_0);

    const Tensor& input_tensor_1 = MklGetInput(context, kInputIndex_1);
    MklDnnShape input_shape_1;
    GetMklShape(context, kInputIndex_1, &input_shape_1);

    VLOG(1) << "MklInputConversionOp: Input shapes are: "
            << context->input(kInputIndex_0).shape().DebugString() << " and "
            << context->input(kInputIndex_1).shape().DebugString();

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // if both inputs are in TF format, just copy input tensors to output.
    if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      VLOG(1) << "MklInputConversionOp: No conversion needed, "
              << "copying TF inputs to output";

      ForwardTfTensorInToOut(context, kInputIndex_0, kInputIndex_0);
      ForwardTfTensorInToOut(context, kInputIndex_1, kInputIndex_1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // If both inputs are in MKL format
    if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      // It is safer to compare the original TensorFlow shapes than to compare
      // Mkl shapes since element wise ops are forwarded to Eigen
      // implementation.
      TensorShape tf_shape0 = input_shape_0.GetTfShape();
      TensorShape tf_shape1 = input_shape_1.GetTfShape();
      TensorShape tensor_shape0 = input_tensor_0.shape();
      TensorShape tensor_shape1 = input_tensor_1.shape();
      if (tf_shape0 == tf_shape1 && tensor_shape0 == tensor_shape1) {
        auto input0_md = input_shape_0.GetMklLayout();
        auto input1_md = input_shape_1.GetMklLayout();

        // If both have the same shape and same format, pass them through
        if (GET_TF_DATA_FORMAT(input_shape_0, input0_md) ==
            GET_TF_DATA_FORMAT(input_shape_1, input1_md)) {
          VLOG(1) << "MklInputConversionOp: No conversion needed, "
                  << "copying MKL inputs with identical shapes to output";

          ForwardMklTensorInToOut(context, kInputIndex_0, kInputIndex_0);
          ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
          return;
        } else {
          VLOG(1) << "MklInputConversionOp: Shape is same, but format is "
                     "different, "
                  << "need to convert to same format";
          // TODO: For now, input0 is converted and input1 is unchanged
          //       we should choose the optimal MKL format to convert to.
          Tensor* tensor_out;
          MklDnnShape mkl_output_mkl_shape;
          mkl_output_mkl_shape.SetMklTensor(true);
          mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
          mkl_output_mkl_shape.SetTfLayout(
              input_shape_0.GetDimension(),
              input_shape_0.GetSizesAsMklDnnDims(),
              input_shape_0.GetTfDataFormat());

          // Get MKL layout from input1 as destination layout
          mkl_output_mkl_shape.SetMklLayout(&input1_md);

          // Create output Mkl tensor for index 0
          AllocateOutputSetMklShape(context, kInputIndex_0, &tensor_out,
                                    input_tensor_0.shape(),
                                    mkl_output_mkl_shape);

          // Create MklDnnData object for input0 tensor
          auto cpu_engine = engine(ENGINE_CPU, 0);
          MklDnnData<T> input(&cpu_engine);
          input.SetUsrMem(input0_md, &input_tensor_0);
          // Create reorder from input0's layout to input1's layout
          std::vector<primitive> net;
          std::vector<MemoryArgsMap> net_args;
          // TODO(bhavanis): Refactor CheckReorderToOpMem() to create and
          // execute reorder
          OP_REQUIRES(
              context,
              input.CheckReorderToOpMem(GET_CHECK_REORDER_TO_OP_MEM_ARGS(
                  input1_md, tensor_out, net, net_args, cpu_engine)),
              errors::Internal(
                  "MklInputConversionOp: Failed to create reorder for input0"));
          ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
          // Input1 will be passed through
          ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
          return;
        }
      }

      // Sanity check. NOTE: this used to be CHECK(false), which aborted the
      // whole process; an internal invariant violation in a kernel should
      // instead be reported through the op's Status.
      bool mkl_shapes_are_same = ((input_shape_0 == input_shape_1) &&
                                  (tensor_shape0 == tensor_shape1));
      OP_REQUIRES(context, !mkl_shapes_are_same,
                  errors::Internal(
                      "MklInputConversionOp: Unexpected: TF shapes are "
                      "different but MKL shapes are same"));

      // Both have different shapes, so broadcast will be necessary.
      // Convert to TF and pass both tensors through (we can't do broadcast
      // with MKL tensors)
      VLOG(1) << "MklInputConversionOp: Broadcast needed, "
              << "converted MKL inputs to TF format";
      // TODO: Cleanup op_data_type and has_avx512f_ after these two parameters
      // are removed from ConvertMklToTf
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           kInputIndex_0);
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           kInputIndex_1);
      SetDummyMklDnnShapeOutput(context, kInputIndex_0);
      SetDummyMklDnnShapeOutput(context, kInputIndex_1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // One input is MKL and one is TF. If no broadcast is needed, convert
    // the TF tensor to MKL, otherwise convert the MKL tensor to TF format
    VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)";

    // Initialized defensively: the "unexpected combination" branch below
    // returns an error without assigning these.
    const Tensor* mkl_tensor = nullptr;
    const MklDnnShape* mkl_shape = nullptr;
    const Tensor* tf_tensor = nullptr;
    uint mkl_tensor_index = 0;
    uint tf_tensor_index = 0;
    if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_0;
      mkl_shape = &input_shape_0;
      mkl_tensor_index = 0;
      tf_tensor = &input_tensor_1;
      tf_tensor_index = 1;
    } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_1;
      mkl_shape = &input_shape_1;
      mkl_tensor_index = 1;
      tf_tensor = &input_tensor_0;
      tf_tensor_index = 0;
    } else {
      // NOTE: this used to be CHECK(false); report an error Status instead
      // of crashing the process (and avoid falling through with the
      // pointers above unset).
      OP_REQUIRES(context, false,
                  errors::Internal(
                      "MklInputConversionOp: Unexpected combination of input "
                      "shapes for MKL element-wise op"));
    }

    // Broadcast is needed if the shapes are not the same
    if (mkl_shape->GetTfShape().num_elements() ==
        tf_tensor->shape().num_elements()) {
      // Both shapes are same, convert the TF input to MKL
      VLOG(1) << "MklInputConversionOp: No broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
              << " to MKL format";

      // Create MklDnnShape for output Mkl tensor.
      Tensor* tensor_out;
      MklDnnShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(true);
      mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
      mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(),
                                       mkl_shape->GetSizesAsMklDnnDims(),
                                       mkl_shape->GetTfDataFormat());
      // ** Temporarily borrow the layout from the MKL input **
      auto output_mkl_md = mkl_shape->GetMklLayout();
      mkl_output_mkl_shape.SetMklLayout(&output_mkl_md);

      // Create output Mkl tensor
      AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out,
                                mkl_tensor->shape(), mkl_output_mkl_shape);

      // Create MklDnnData object for input tensor. Input tensor is in
      // Tensorflow layout.
      auto cpu_engine = engine(ENGINE_CPU, 0);
      MklDnnData<T> tf_input(&cpu_engine);
      auto input_tf_md = mkl_output_mkl_shape.GetTfLayout();
      tf_input.SetUsrMem(input_tf_md, tf_tensor);
      // Create reorder between TF layout and MKL layout if necessary
      std::vector<primitive> net;
      std::vector<MemoryArgsMap> net_args;
      bool reordered =
          tf_input.CheckReorderToOpMem(GET_CHECK_REORDER_TO_OP_MEM_ARGS(
              output_mkl_md, tensor_out, net, net_args, cpu_engine));
      if (!reordered) {
        // This is the case that the TF tensor has the same shape and format of
        // mkl tensor. However, tf_tensor can not be simply forwarded to the
        // output tensor since mkl data tensor is always one dimensional tensor.
        // Tensor::CopyFrom shares the buffer of the other tensor while set its
        // shape to the other tensor.
        OP_REQUIRES(context,
                    tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()),
                    errors::Internal("MklInputConversionOp: Failed to forward "
                                     "input tensor to output"));
      } else {
        ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
      }

      // -- The tensor in MKL format passes through --
      ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
    } else {
      // Broadcast is needed, so convert the MKL input to TF
      VLOG(1) << "MklInputConversionOp: Broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index
              << " to TF format";
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           mkl_tensor_index);
      SetDummyMklDnnShapeOutput(context, mkl_tensor_index);

      // The tensor in TF format passes through
      ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
    }

    VLOG(1) << "MklInputConversionOp: Shapes (output): "
            << context->mutable_output(kInputIndex_0)->shape().DebugString()
            << " and "
            << context->mutable_output(kInputIndex_1)->shape().DebugString();

    VLOG(1) << "MklInputConversion completed successfully.";
  }

 private:
  /// Data format of the operation
  string data_format_str;

  /// Data type of the operation
  DataType op_data_type;

  /// CPUIDInfo
  bool has_avx512f_ = false;
};

///////////////////////////////////////////////////////////
// Register kernel
///////////////////////////////////////////////////////////

#define REGISTER_CPU(T)                                        \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklInputConversion")                              \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklInputConversionOp<CPUDevice, T>);

// TODO(nhasabni): We cannot support all number types since MklDnn does
// not support types.
// TF_CALL_NUMBER_TYPES(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);

#undef REGISTER_CPU
#undef ENGINE_CPU
#undef GET_CHECK_REORDER_TO_OP_MEM_ARGS
#undef GET_TF_DATA_FORMAT
#undef NET_ARGS_PTR

}  // namespace tensorflow
#endif  // INTEL_MKL