1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifdef INTEL_MKL 17 18 #include <algorithm> 19 #include <vector> 20 21 #include "mkldnn.hpp" 22 #include "tensorflow/core/framework/numeric_op.h" 23 #include "tensorflow/core/framework/op.h" 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/register_types.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 #include "tensorflow/core/kernels/mkl/mkl_tfconv_op.h" 29 #include "tensorflow/core/kernels/ops_util.h" 30 #include "tensorflow/core/platform/byte_order.h" 31 #include "tensorflow/core/platform/cpu_info.h" 32 #include "tensorflow/core/platform/macros.h" 33 #include "tensorflow/core/util/mkl_util.h" 34 #include "tensorflow/core/util/tensor_format.h" 35 36 namespace tensorflow { 37 38 /////////////////////////////////////////////////////////// 39 // Op kernel 40 // Checks and ensures that the 2 inputs are compatible for mkl binary ops. 
// Here's the basic logic:
//
// if both inputs are in TF format:
//   pass the inputs through to the output
// else if both inputs are in mkl format:
//   if both have the same shape:
//     pass the inputs through to the output
//   else:
//     convert both to TF
// else if one is TF and one is MKL:
//   if broadcast is needed:
//     convert the MKL format input to TF format
//   else:
//     convert the TF format input to MKL format
///////////////////////////////////////////////////////////

// Normalizes the layouts of the two inputs of an MKL element-wise binary op
// so that the downstream kernel receives a compatible pair. Emits, for each
// input index, both a data tensor and its MklDnnShape meta tensor (the
// dual-tensor convention used by MKL layout-dependent ops in this file's
// helpers: MklGetInput/GetMklShape/AllocateOutputSetMklShape).
template <typename Device, typename T>
class MklInputConversionOp : public OpKernel {
 public:
  explicit MklInputConversionOp(OpKernelConstruction* context)
      : OpKernel(context) {
    // "data_format" / "T" attrs are only needed for the MKL->TF conversion
    // path (ConvertMklToTf); see TODO below about removing them.
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
    has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
  }

 private:
  // Inspects both inputs' MklDnnShape meta tensors and forwards/reorders/
  // converts the inputs as described in the comment block above. Always
  // produces outputs at the same indices as the inputs (0 and 1).
  void Compute(OpKernelContext* context) override {
    const int kInputIndex_0 = 0, kInputIndex_1 = 1;
    const Tensor& input_tensor_0 = MklGetInput(context, kInputIndex_0);
    MklDnnShape input_shape_0;
    GetMklShape(context, kInputIndex_0, &input_shape_0);

    const Tensor& input_tensor_1 = MklGetInput(context, kInputIndex_1);
    MklDnnShape input_shape_1;
    GetMklShape(context, kInputIndex_1, &input_shape_1);

    VLOG(1) << "MklInputConversionOp: Input shapes are: "
            << context->input(kInputIndex_0).shape().DebugString() << " and "
            << context->input(kInputIndex_1).shape().DebugString();

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // if both inputs are in TF format, just copy input tensors to output.
    if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      VLOG(1) << "MklInputConversionOp: No conversion needed, "
              << "copying TF inputs to output";

      ForwardTfTensorInToOut(context, kInputIndex_0, kInputIndex_0);
      ForwardTfTensorInToOut(context, kInputIndex_1, kInputIndex_1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // If both inputs are in MKL format
    if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      // It is safer to compare the original TensorFlow shapes than to compare
      // Mkl shapes since element wise ops are forwarded to Eigen
      // implementation.
      TensorShape tf_shape0 = input_shape_0.GetTfShape();
      TensorShape tf_shape1 = input_shape_1.GetTfShape();
      TensorShape tensor_shape0 = input_tensor_0.shape();
      TensorShape tensor_shape1 = input_tensor_1.shape();
      if (tf_shape0 == tf_shape1 && tensor_shape0 == tensor_shape1) {
        auto input0_md = input_shape_0.GetMklLayout();
        auto input1_md = input_shape_1.GetMklLayout();

        // If both have the same shape and same format, pass them through
        if (input_shape_0.GetTfDataFormat() ==
            input_shape_1.GetTfDataFormat()) {
          VLOG(1) << "MklInputConversionOp: No conversion needed, "
                  << "copying MKL inputs with identical shapes to output";

          ForwardMklTensorInToOut(context, kInputIndex_0, kInputIndex_0);
          ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
          return;
        } else {
          // Same logical shape, different MKL layouts: reorder input0 into
          // input1's layout so both outputs share one layout.
          VLOG(1) << "MklInputConversionOp: Shape is same, but format is "
                     "different, "
                  << "need to convert to same format";
          // TODO: For now, input0 is converted and input1 is unchanged
          //       we should choose the optimal MKL format to convert to.
          Tensor* tensor_out;
          MklDnnShape mkl_output_mkl_shape;
          mkl_output_mkl_shape.SetMklTensor(true);
          mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
          mkl_output_mkl_shape.SetTfLayout(input_shape_0.GetDimension(),
                                           input_shape_0.GetSizesAsMklDnnDims(),
                                           input_shape_0.GetTfDataFormat());

          // Get MKL layout from input1 as destination layout
          mkl_output_mkl_shape.SetMklLayout(&input1_md);

          // Create output Mkl tensor for index 0
          AllocateOutputSetMklShape(context, kInputIndex_0, &tensor_out,
                                    input_tensor_0.shape(),
                                    mkl_output_mkl_shape);

          // Create MklDnnData object for input0 tensor
          auto cpu_engine = engine(engine::kind::cpu, 0);
          MklDnnData<T> input(&cpu_engine);
          input.SetUsrMem(input0_md, &input_tensor_0);
          // Create reorder from input0's layout to input1's layout
          std::vector<primitive> net;
          std::vector<MemoryArgsMap> net_args;
          // TODO(bhavanis): Refactor CheckReorderToOpMem() to create and
          // execute reorder
          OP_REQUIRES(
              context,
              input.CheckReorderToOpMem(input1_md, tensor_out, net, net_args,
                                        cpu_engine),
              errors::Internal(
                  "MklInputConversionOp: Failed to create reorder for input0"));
          ExecutePrimitive(net, &net_args, cpu_engine, context);
          // Input1 will be passed through
          ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
          return;
        }
      }

      // Sanity check
      bool mkl_shapes_are_same = ((input_shape_0 == input_shape_1) &&
                                  (tensor_shape0 == tensor_shape1));
      if (mkl_shapes_are_same) {
        // NOTE(review): CHECK aborts the process on an internal invariant
        // violation; this state should be unreachable given the branch above.
        CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
                        "different but MKL shapes are same";
      }

      // Both have different shapes, so broadcast will be necessary.
      // Convert to TF and pass both tensors through (we can't do broadcast
      // with MKL tensors)
      VLOG(1) << "MklInputConversionOp: Broadcast needed, "
              << "converted MKL inputs to TF format";
      // TODO: Cleanup op_data_type and has_avx512f_ after these two parameters
      // are removed from ConvertMklToTf
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           kInputIndex_0);
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           kInputIndex_1);
      SetDummyMklDnnShapeOutput(context, kInputIndex_0);
      SetDummyMklDnnShapeOutput(context, kInputIndex_1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // One input is MKL and one is TF. If no broadcast is needed, convert
    // the TF tensor to MKL, otherwise convert the MKL tensor to TF format
    VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)";

    // Identify which input is which so the logic below is symmetric.
    const Tensor* mkl_tensor;
    const MklDnnShape* mkl_shape;
    const Tensor* tf_tensor;
    uint mkl_tensor_index;
    uint tf_tensor_index;
    if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_0;
      mkl_shape = &input_shape_0;
      mkl_tensor_index = 0;
      tf_tensor = &input_tensor_1;
      tf_tensor_index = 1;
    } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_1;
      mkl_shape = &input_shape_1;
      mkl_tensor_index = 1;
      tf_tensor = &input_tensor_0;
      tf_tensor_index = 0;
    } else {
      // Unreachable: the two earlier branches already handled both-TF and
      // both-MKL.
      CHECK(false) << "MklInputConversionOp: Unexpected combination of input "
                      "shapes for MKL "
                   << "element-wise op";
    }

    // Broadcast is needed if the shapes are not the same
    // (element counts are compared, not dims; shapes here differ only when
    // one side must broadcast).
    if (mkl_shape->GetTfShape().num_elements() ==
        tf_tensor->shape().num_elements()) {
      // Both shapes are same, convert the TF input to MKL
      VLOG(1) << "MklInputConversionOp: No broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
              << " to MKL format";

      // Create MklDnnShape for output Mkl tensor.
      Tensor* tensor_out;
      MklDnnShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(true);
      mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
      mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(),
                                       mkl_shape->GetSizesAsMklDnnDims(),
                                       mkl_shape->GetTfDataFormat());
      // ** Temporarily borrow the layout from the MKL input **
      auto output_mkl_md = mkl_shape->GetMklLayout();
      mkl_output_mkl_shape.SetMklLayout(&output_mkl_md);

      // Create output Mkl tensor
      AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out,
                                mkl_tensor->shape(), mkl_output_mkl_shape);

      // Create MklDnnData object for input tensor. Input tensor is in
      // Tensorflow layout.
      auto cpu_engine = engine(engine::kind::cpu, 0);
      MklDnnData<T> tf_input(&cpu_engine);
      auto input_tf_md = mkl_output_mkl_shape.GetTfLayout();
      tf_input.SetUsrMem(input_tf_md, tf_tensor);
      // Create reorder between TF layout and MKL layout if necessary
      std::vector<primitive> net;
      std::vector<MemoryArgsMap> net_args;
      bool reordered = tf_input.CheckReorderToOpMem(output_mkl_md, tensor_out,
                                                    net, net_args, cpu_engine);
      if (!reordered) {
        // This is the case that the TF tensor has the same shape and format of
        // mkl tensor. However, tf_tensor can not be simply forwarded to the
        // output tensor since mkl data tensor is always one dimensional tensor.
        // Tensor::CopyFrom shares the buffer of the other tensor while set its
        // shape to the other tensor.
        OP_REQUIRES(context,
                    tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()),
                    errors::Internal("MklInputConversionOp: Failed to forward "
                                     "input tensor to output"));
      } else {
        ExecutePrimitive(net, &net_args, cpu_engine, context);
      }

      // -- The tensor in MKL format passes through --
      ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
    } else {
      // Broadcast is needed, so convert the MKL input to TF
      VLOG(1) << "MklInputConversionOp: Broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index
              << " to TF format";
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           mkl_tensor_index);
      SetDummyMklDnnShapeOutput(context, mkl_tensor_index);

      // The tensor in TF format passes through
      ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
    }

    VLOG(1) << "MklInputConversionOp: Shapes (output): "
            << context->mutable_output(kInputIndex_0)->shape().DebugString()
            << " and "
            << context->mutable_output(kInputIndex_1)->shape().DebugString();

    VLOG(1) << "MklInputConversion completed successfully.";
  }

 private:
  /// Data format of the operation (only consumed by ConvertMklToTf)
  string data_format_str;

  /// Data type of the operation
  DataType op_data_type;

  /// CPUIDInfo: true when the host CPU supports AVX512F
  bool has_avx512f_ = false;
};

///////////////////////////////////////////////////////////
// Register kernel
///////////////////////////////////////////////////////////

// Registers this op for one element type T under the MKL layout-dependent
// label so the MKL graph-rewrite pass can select it.
#define REGISTER_CPU(T)                                        \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklInputConversion")                              \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklInputConversionOp<CPUDevice, T>);

// TODO(nhasabni): We cannot support all number types since MklDnn does
// not support types.
// Only float and bfloat16 kernels are registered; the full numeric-type set
// is intentionally left disabled (see commented-out macro call).
// TF_CALL_NUMBER_TYPES(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);

#undef REGISTER_CPU

}  // namespace tensorflow
#endif  // INTEL_MKL