/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifdef INTEL_MKL
#define EIGEN_USE_THREADS

#include <numeric>

#include "mkldnn.hpp"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::stream;
using mkldnn::sum;

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T>
class MklAddNOp : public OpKernel {
 public:
  ~MklAddNOp() {}
  explicit MklAddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  // Returns the TF shape of the tensor at src_index, unpacking the MKL
  // meta-tensor if the input is in MKL layout.
  TensorShape GetTensorShape(OpKernelContext* ctx, size_t src_index) {
    const Tensor& src_tensor = MklGetInput(ctx, src_index);
    MklDnnShape src_mkl_shape;
    GetMklShape(ctx, src_index, &src_mkl_shape);
    return src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape()
                                       : src_tensor.shape();
  }

  // Returns false and sets an InvalidArgument status if any input shape
  // differs from the shape of input 0.
  bool CheckInputShape(OpKernelContext* ctx) {
    const int num_inputs = ctx->num_inputs() / 2;
    const TensorShape src0_shape = GetTensorShape(ctx, 0);

    for (int i = 1; i < num_inputs; ++i) {
      if (!src0_shape.IsSameSize(GetTensorShape(ctx, i))) {
        ctx->SetStatus(errors::InvalidArgument(
            "Inputs to operation ", this->name(), " of type ",
            this->type_string(),
            " must have the same size and shape. Input 0: ",
            src0_shape.DebugString(), " != input ", i, ": ",
            GetTensorShape(ctx, i).DebugString()));
        return false;
      }
    }

    return true;
  }

  // Returns the index of the first input tensor that is in MKL layout, or -1
  // if no input is in MKL layout.
  int FindMKLInputIndex(OpKernelContext* ctx) {
    int mkl_index = -1;
    const int num_inputs = ctx->num_inputs() / 2;

    MklDnnShape src_mkl_shape;
    for (int i = 0; i < num_inputs; ++i) {
      GetMklShape(ctx, i, &src_mkl_shape);
      if (src_mkl_shape.IsMklTensor()) {
        mkl_index = i;
        break;
      }
    }

    return mkl_index;
  }

  // Fast path for 0-D (scalar) inputs: sum the scalars directly instead of
  // building an MKL-DNN primitive.
  void ComputeScalar(OpKernelContext* ctx) {
    const int num_inputs = ctx->num_inputs() / 2;
    const size_t kOutputIdx = 0;
    TensorShape output_tf_shape;
    MklDnnShape output_mkl_shape;
    Tensor* dst_tensor = nullptr;

    T sum = static_cast<T>(0);
    for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
      const Tensor& src_tensor = MklGetInput(ctx, src_idx);
      T* src_i = const_cast<T*>(src_tensor.flat<T>().data());
      sum += src_i[0];
    }

    output_mkl_shape.SetMklTensor(false);
    output_tf_shape = MklGetInput(ctx, kOutputIdx).shape();
    AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
                              output_mkl_shape);

    T* out_o = dst_tensor->flat<T>().data();
    out_o[0] = sum;
  }

  void Compute(OpKernelContext* ctx) override {
    // Each input tensor in MKL layout has an additional meta-tensor carrying
    // layout information, so the number of actual input tensors is half the
    // total number of inputs.
    const int num_inputs = ctx->num_inputs() / 2;

    MklDnnShape mkl_shape;
    const size_t kSrc0Idx = 0;
    const size_t kOutputIdx = 0;

    // With a single input there is nothing to add; forward it to the output.
    if (num_inputs == 1) {
      GetMklShape(ctx, kSrc0Idx, &mkl_shape);
      bool input_in_mkl_format = mkl_shape.IsMklTensor();

      if (input_in_mkl_format) {
        ForwardMklTensorInToOut(ctx, kSrc0Idx, kOutputIdx);
      } else {
        ForwardTfTensorInToOut(ctx, kSrc0Idx, kOutputIdx);
      }
      return;
    }

    // Check that all inputs have the same shape.
    if (!CheckInputShape(ctx)) return;

    try {
      TensorShape output_tf_shape;
      MklDnnShape output_mkl_shape;
      const Tensor& src_tensor = MklGetInput(ctx, kSrc0Idx);

      Tensor* dst_tensor = nullptr;

      // Nothing to compute, return.
      if (src_tensor.shape().num_elements() == 0) {
        output_mkl_shape.SetMklTensor(false);
        output_tf_shape = src_tensor.shape();
        AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor,
                                  output_tf_shape, output_mkl_shape);
        return;
      }

      if (src_tensor.dims() == 0) {
        ComputeScalar(ctx);
        return;
      }

      auto cpu_engine = engine(engine::kind::cpu, 0);
      std::vector<float> coeff(num_inputs, 1.0f);
      std::vector<memory::desc> srcs_pd;
      std::vector<memory> inputs;

      MklDnnData<T> dst(&cpu_engine);
      MklDnnData<T> src(&cpu_engine);
      bool has_mkl_input = false;
      int mkl_input_index = FindMKLInputIndex(ctx);
      MklTensorFormat mkl_data_format;
      TensorFormat tf_data_format;
      memory::format_tag dnn_fmt = memory::format_tag::any;
      if (mkl_input_index >= 0) {
        has_mkl_input = true;
        GetMklShape(ctx, mkl_input_index, &mkl_shape);
        // The MKL input carries the data format information.
        mkl_data_format = mkl_shape.GetTfDataFormat();
        tf_data_format = MklDnnDataFormatToTFDataFormat(mkl_data_format);
        dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
      }

      std::shared_ptr<stream> fwd_cpu_stream;
      MklDnnThreadPool eigen_tp(ctx);
      fwd_cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));

      // Create memory descriptors for MKL-DNN.
      // If all inputs are in TensorFlow format, create blocked memory
      // descriptors; otherwise convert the TF shapes to MKL memory
      // descriptors.
      for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
        MklDnnShape src_mkl_shape;
        GetMklShape(ctx, src_idx, &src_mkl_shape);
        memory::desc md({}, memory::data_type::undef,
                        memory::format_tag::undef);
        const Tensor& src_tensor = MklGetInput(ctx, src_idx);

        if (src_mkl_shape.IsMklTensor()) {
          md = src_mkl_shape.GetMklLayout();
        } else {
          if (has_mkl_input) {
            memory::dims src_dims;
            if (src_tensor.dims() == 4) {
              src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
                                                   tf_data_format);
            } else {
              DCHECK(src_tensor.dims() == 5);
              src_dims = TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
                                                    tf_data_format);
            }
            md = memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
          } else {
            // Create a blocked memory descriptor for the TensorFlow-format
            // input.
            auto dims = TFShapeToMklDnnDims(src_tensor.shape());
            auto strides = CalculateTFStrides(dims);
            md = MklDnnData<T>::CreateBlockedMemDesc(dims, strides);
          }
        }
        srcs_pd.push_back(memory::desc(md));
        src.SetUsrMem(md, &src_tensor);
        src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);
        inputs.push_back(src.GetOpMem());
      }

      auto sum_pd = sum::primitive_desc(coeff, srcs_pd, cpu_engine);
      output_mkl_shape.SetMklTensor(has_mkl_input);
      auto output_pd = sum_pd.dst_desc();
      dst.SetUsrMem(output_pd);

      if (has_mkl_input) {
        output_mkl_shape.SetMklLayout(&output_pd);
        output_mkl_shape.SetElemType(MklDnnType<T>());
        output_mkl_shape.SetTfLayout(mkl_shape.GetDimension(),
                                     mkl_shape.GetSizesAsMklDnnDims(),
                                     mkl_shape.GetTfDataFormat());
        output_tf_shape.AddDim(output_pd.get_size() / sizeof(T));
      } else {
        // All inputs have TF shapes; take the output shape from the first
        // one.
        output_tf_shape = MklGetInput(ctx, kSrc0Idx).shape();
      }
      AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
                                output_mkl_shape);
      dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);

      // Create the Sum op and submit it for execution.
      mkldnn::sum sum_op(sum_pd);
      std::unordered_map<int, memory> net_args = {
          {MKLDNN_ARG_DST, dst.GetOpMem()}};
      for (int i = 0; i < num_inputs; ++i) {
        net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, inputs[i]});
      }
      sum_op.execute(*fwd_cpu_stream, net_args);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          ctx, errors::Aborted("Operation received an exception:", error_msg));
    }
  }
};

#define REGISTER_MKL_CPU(T)                                    \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklAddN")                                         \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklAddNOp<CPUDevice, T>);

TF_CALL_float(REGISTER_MKL_CPU);
TF_CALL_bfloat16(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU

}  // namespace tensorflow
#endif  // INTEL_MKL