/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_

#ifdef INTEL_MKL
#include <memory>
#include <string>
#include <vector>

#include "mkldnn.hpp"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::inner_product_forward;
using mkldnn::primitive_attr;
using mkldnn::prop_kind;
using mkldnn::stream;

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

#ifdef INTEL_MKL_DNN_ONLY
// Temporarily copying some definitions from mkl_cblas.h so the same code can
// be used when calling oneDNN or CBLAS batchmatmul in mkl_batch_matmul_op.cc.
typedef enum { CblasRowMajor, CblasColMajor } CBLAS_LAYOUT;
#define MKL_INT int
#endif

// This structure aggregates multiple inputs to MklDnnMatMul* methods.
struct MklDnnMatMulFwdParams {
  memory::dims src_dims;
  memory::dims weight_dims;
  memory::dims bias_dims;
  memory::dims dst_dims;
  MEMORY_FORMAT src_format;
  MEMORY_FORMAT weight_format;
  MEMORY_FORMAT dst_format;
  string dtypes = string("");
  struct PostOpParam {
    string name;
    std::vector<float> param;
  };
  std::vector<PostOpParam> post_op_params;

  MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
                        memory::dims bias_dims, memory::dims dst_dims,
                        MEMORY_FORMAT src_format = MEMORY_FORMAT::any,
                        MEMORY_FORMAT weight_format = MEMORY_FORMAT::any,
                        MEMORY_FORMAT dst_format = MEMORY_FORMAT::any)
      : src_dims(src_dims),
        weight_dims(weight_dims),
        bias_dims(bias_dims),
        dst_dims(dst_dims),
        src_format(src_format),
        weight_format(weight_format),
        dst_format(dst_format) {}
};
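// Illustrative sketch (not part of this header's API): how a caller might
// populate MklDnnMatMulFwdParams for a 2-D inner product with a fused ReLU
// post-op. The dimension names are placeholders, and the weight-dims order is
// only an assumed example; the "relu" post-op takes three floats
// (scale, alpha, beta), matching Setup() further down.
//
//   memory::dims src_dims = {batch, in_channels};
//   memory::dims weight_dims = {out_channels, in_channels};
//   memory::dims bias_dims = {out_channels};
//   memory::dims dst_dims = {batch, out_channels};
//   MklDnnMatMulFwdParams fwd_params(src_dims, weight_dims, bias_dims,
//                                    dst_dims);
//   fwd_params.dtypes.append(typeid(float).name());
//   fwd_params.post_op_params.push_back({"relu", {1.0f, 0.0f, 0.0f}});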
// With quantization, input, weight, bias, and output can have different types.
// So we use different template parameters for each type.
// TODO(intel-tf): The template type "T" is currently used to match the
// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h).
// In the future, with the removal of "T" from MklPrimitiveFactory, this class
// needs to drop "T".
template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 public:
  explicit MklDnnMatMulFwdPrimitive(
      const MklDnnMatMulFwdParams& matmulFwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create matmul primitive
    if (context_.matmul_fwd == nullptr) {
      Setup(matmulFwdParams);
    }
  }

  ~MklDnnMatMulFwdPrimitive() {}

  // Inner-product forward execute with bias:
  //  - src_data: input data buffer of src
  //  - weight_data: input data buffer of weight
  //  - bias_data: input data buffer of bias
  //  - dst_data: output data buffer of dst
  void Execute(const Tinput* src_data, const Tweight* weight_data,
               const Tbias* bias_data, Toutput* dst_data,
               std::shared_ptr<stream> fwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)), *fwd_stream);
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
                                      *fwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)));
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)));
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#endif  // ENABLE_MKLDNN_THREADPOOL

    execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);

    // After execution, set data handle back
    context_.src_mem->set_data_handle(DummyData);
    context_.weight_mem->set_data_handle(DummyData);
    context_.bias_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>
  GetPrimitiveDesc() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for inner-product Fwd op
  struct MklDnnMatMulFwdContext {
    // MKL-DNN memory.
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> weight_mem;
    std::shared_ptr<mkldnn::memory> bias_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;

    // Descriptor and primitive-descriptor for forward inner-product.
    std::shared_ptr<mkldnn::inner_product_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> fwd_pd;

    // Memory descriptors.
    std::shared_ptr<mkldnn::memory::desc> src_md;
    std::shared_ptr<mkldnn::memory::desc> weight_md;
    std::shared_ptr<mkldnn::memory::desc> bias_md;
    std::shared_ptr<mkldnn::memory::desc> dst_md;
    // Inner-product primitive.
    std::shared_ptr<mkldnn::primitive> matmul_fwd;
    std::vector<mkldnn::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    MklDnnMatMulFwdContext()
        : src_mem(nullptr),
          weight_mem(nullptr),
          bias_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          weight_md(nullptr),
          bias_md(nullptr),
          dst_md(nullptr),
          matmul_fwd(nullptr) {}
  };

  void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) {
    // Create memory descriptors for inner-product data without specified
    // format.
    context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
                                           MklDnnType<Tinput>(),
                                           matmul_fwd_params.src_format));

    context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
                                               MklDnnType<Tweight>(),
                                               matmul_fwd_params.weight_format));

    context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
                                           MklDnnType<Toutput>(),
                                           matmul_fwd_params.dst_format));

    context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
                                            MklDnnType<Tbias>(),
                                            memory::format_tag::any));
    // Create an inner-product.
    context_.fwd_desc.reset(new inner_product_forward::desc(
        prop_kind::forward_inference, *context_.src_md, *context_.weight_md,
        *context_.bias_md, *context_.dst_md));
    context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
        *context_.fwd_desc, cpu_engine_));

    // Check if there is any fusion as post-ops
    auto const& post_op_params = matmul_fwd_params.post_op_params;
    mkldnn::primitive_attr post_ops_attr;
    mkldnn::post_ops post_ops;
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "relu" || post_op_param.name == "leakyrelu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "relu6") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale,
                                  mkldnn::algorithm::eltwise_bounded_relu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "elu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_elu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "tanh") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_tanh,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "output_scale") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          std::vector<float> scales;
          scales.push_back(post_op_param.param[0]);
          post_ops_attr.set_output_scales(0, scales);
        } else if (post_op_param.name == "sum") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          float op_scale = post_op_param.param[0];
          post_ops.append_sum(op_scale);
        } else {
          DCHECK((post_op_param.name == "relu") ||
                 (post_op_param.name == "relu6") ||
                 (post_op_param.name == "elu") ||
                 (post_op_param.name == "tanh") ||
                 (post_op_param.name == "sum") ||
                 (post_op_param.name == "leakyrelu") ||
                 (post_op_param.name == "output_scale"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, post_ops_attr, cpu_engine_));
    } else {
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, cpu_engine_));
    }

    // Create memory primitive based on dummy data
    context_.src_mem.reset(
        new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
    context_.weight_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
                                         cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));
    context_.bias_mem.reset(new memory({{matmul_fwd_params.bias_dims},
                                        MklDnnType<Tbias>(),
                                        memory::format_tag::x},
                                       cpu_engine_, DummyData));

    // Create inner-product primitive.
    context_.matmul_fwd.reset(new inner_product_forward(*context_.fwd_pd));
    context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
                                 {MKLDNN_ARG_WEIGHTS, *context_.weight_mem},
                                 {MKLDNN_ARG_BIAS, *context_.bias_mem},
                                 {MKLDNN_ARG_DST, *context_.dst_mem}});

    context_.fwd_primitives.push_back(*context_.matmul_fwd);
    return;
  }

  struct MklDnnMatMulFwdContext context_;
};

template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* Get(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) {
    MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* matmul_fwd =
        nullptr;

    if (do_not_cache) {
      // Always create new primitive
      matmul_fwd =
          new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
              mkldnn_matmul_fwd_dims);
    } else {
      // Try to find a suitable one in pool
      matmul_fwd = dynamic_cast<
          MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>*>(
          MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                          Toutput>::GetInstance()
              .GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims));
      if (matmul_fwd == nullptr) {
        matmul_fwd =
            new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
                mkldnn_matmul_fwd_dims);
        MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                        Toutput>::GetInstance()
            .SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd);
      }
    }
    return matmul_fwd;
  }

 private:
  MklDnnMatMulFwdPrimitiveFactory() {}
  ~MklDnnMatMulFwdPrimitiveFactory() {}

  static MklDnnMatMulFwdPrimitiveFactory& GetInstance() {
    static MklDnnMatMulFwdPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string prefix = "matmul_fwd_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_format);

    // Generate keys for post-ops
    for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
      if (post_op_param.name == "relu" || post_op_param.name == "relu6" ||
          post_op_param.name == "elu" || post_op_param.name == "tanh" ||
          post_op_param.name == "leakyrelu") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
        key_creator.AddAsKey(post_op_param.param[1]);
        key_creator.AddAsKey(post_op_param.param[2]);
      } else if (post_op_param.name == "sum") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else if (post_op_param.name == "output_scale") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else {
        return string("not_a_key");
      }
    }
    return key_creator.GetKey();
  }

  MklPrimitive* GetMklDnnMatMulFwd(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    return this->GetOp(key);
  }

  void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims,
                          MklPrimitive* op) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    this->SetOp(key, op);
  }
};
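// Illustrative sketch (placeholders noted): the flow a kernel typically
// follows to obtain a cached forward primitive and run it. The buffer
// pointers (src_data, weight_data, bias_data, dst_data) and the kernel
// context "ctx" are assumed to be supplied by the calling op.
//
//   MklDnnMatMulFwdParams fwd_params(src_dims, weight_dims, bias_dims,
//                                    dst_dims);
//   auto* matmul_fwd =
//       MklDnnMatMulFwdPrimitiveFactory<float, float, float, float,
//                                       float>::Get(fwd_params,
//                                                   /*do_not_cache=*/false);
//   std::shared_ptr<stream> fwd_stream;
//   fwd_stream.reset(CreateStream(ctx, matmul_fwd->GetEngine()));
//   matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data,
//                       fwd_stream);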
template <class Tweight, class Toutput>
class MklDnnMatMulOpBase : public OpKernel {
 public:
  explicit MklDnnMatMulOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}
  void Compute(OpKernelContext* context) override = 0;

  // Allocate output tensor.
  virtual void AllocateOutputTensor(
      OpKernelContext* context,
      const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc,
      const memory::dims& output_dims_mkl_order,
      MklTensorFormat output_tf_format, Tensor** output_tensor) {
    DCHECK(output_tensor);
    auto dst_pd = mkldnn_matmul_prim_desc.dst_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<Toutput>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));

    // Allocate Output Tensor
    AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  // TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
  // be acquired before entering the function, since it is acquired
  // inside the function.
  inline bool IsWeightCacheEmpty(OpKernelContext* context)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    return (weight_oi_.NumElements() == 0);
  }

  // Cache the converted weight in a persistent tensor.
  // Only one thread can execute this method at any given time.
  void CacheWeight(
      OpKernelContext* context,
      const std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>&
          matmul_fwd_pd,
      Tweight* weight_data, const Tensor& weight_tensor,
      MklDnnData<Tweight>& weight, const memory::desc& weight_md)
      TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    const Tensor& weight_t = *weight_oi_.AccessTensor(context);

    // If the weights are already cached, there's nothing to do
    if (weight_t.NumElements() > 0) {
      return;
    }

    // Reorder and cache the weight
    weight.SetUsrMem(weight_md, &weight_tensor);
    weight.CheckReorderToOpMem(matmul_fwd_pd.get()->weights_desc(), cpu_engine_,
                               context);
    weight_data = static_cast<Tweight*>(weight.GetOpMem().get_data_handle());

    Tensor* weight_tensor_ptr = nullptr;

    size_t weight_size = matmul_fwd_pd.get()->weights_desc().get_size();
    TensorShape weight_tf_shape;
    weight_tf_shape.AddDim(weight_size / sizeof(Tweight));

    OP_REQUIRES_OK(context, context->allocate_persistent(
                                DataTypeToEnum<Tweight>::value, weight_tf_shape,
                                &weight_oi_, &weight_tensor_ptr));

    void* weight_oi_t_data = weight.GetTensorBuffer(weight_tensor_ptr);
    memcpy(weight_oi_t_data, weight_data, weight_size);

    // Cache the memory descriptor
    auto expected_md = matmul_fwd_pd->weights_desc();
    Tensor* weight_md_tensor_ptr = nullptr;
    TensorShape weight_mkl_format;
    weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight));

    OP_REQUIRES_OK(context, context->allocate_persistent(
                                DataTypeToEnum<Tweight>::value,
                                weight_mkl_format, &weight_oi_md_,
                                &weight_md_tensor_ptr));
    *reinterpret_cast<memory::desc*>(
        weight_md_tensor_ptr->flat<Tweight>().data()) = expected_md;
  }

  Tweight* GetCachedWeight(OpKernelContext* context,
                           const memory::desc& expected_md)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    const Tensor& weight_t = *weight_oi_.AccessTensor(context);
    const Tensor& weight_md_t = *weight_oi_md_.AccessTensor(context);

    // Check if the memory descriptor of the cached weight is the same as
    // expected_md. If so, use the cached memory; otherwise return nullptr.
    if (weight_md_t.flat<Tweight>().size()) {
      const memory::desc& stored_md =
          *(static_cast<memory::desc*>(weight_md_t.data()));
      if (stored_md == expected_md) {
        return static_cast<Tweight*>(
            const_cast<Tweight*>(weight_t.flat<Tweight>().data()));
      }
    }
    return nullptr;
  }

  engine cpu_engine_ = engine(engine::kind::cpu, 0);

 protected:
  // Tensor to save reordered weight
  mutex mu_;
  PersistentTensor weight_oi_ TF_GUARDED_BY(mu_);
  PersistentTensor weight_oi_md_ TF_GUARDED_BY(mu_);

  bool is_weight_const_;

  const int kInputIndexSrc = 0;
  const int kInputIndexWeight = 1;
  const int kInputIndexBias = 2;
  const int kOutputIndexDst = 0;
};

using mkldnn::matmul;

namespace {

struct MklMatMulParams {
  memory::dims a_dims;
  memory::dims b_dims;
  memory::dims c_dims;
  memory::dims a_strides;
  memory::dims b_strides;
  memory::dims c_strides;

  MklMatMulParams(memory::dims a_dims, memory::dims b_dims, memory::dims c_dims,
                  memory::dims a_strides, memory::dims b_strides,
                  memory::dims c_strides)
      : a_dims(a_dims),
        b_dims(b_dims),
        c_dims(c_dims),
        a_strides(a_strides),
        b_strides(b_strides),
        c_strides(c_strides) {}
};
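// Illustrative sketch: dims and row-major strides for a plain (untransposed)
// M x K times K x N matmul, mirroring what dnnl_gemm() below constructs.
// M, K, and N are placeholder extents.
//
//   memory::dims a_dims = {M, K};
//   memory::dims b_dims = {K, N};
//   memory::dims c_dims = {M, N};
//   MklMatMulParams params(a_dims, b_dims, c_dims,
//                          /*a_strides=*/{K, 1}, /*b_strides=*/{N, 1},
//                          /*c_strides=*/{N, 1});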
template <typename T>
class MklMatMulPrimitive : public MklPrimitive {
 public:
  explicit MklMatMulPrimitive(const MklMatMulParams& params)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create matmul primitive
    Setup(params);
  }

  ~MklMatMulPrimitive() {}

  void Execute(const T* a_data, const T* b_data, T* c_data,
               std::shared_ptr<stream> stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
    context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)),
                                    *stream);
    context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)),
                                    *stream);
    context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)),
                                    *stream);
#else
    context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
    context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
    context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)));
#endif  // ENABLE_MKLDNN_THREADPOOL
    execute_primitives(context_.matmul_primitives, stream, context_.net_args);

    // After execution, set data handle back
    context_.a_mem->set_data_handle(DummyData);
    context_.b_mem->set_data_handle(DummyData);
    context_.c_mem->set_data_handle(DummyData);
  }

 private:
  // Primitive reuse context for MatMul op
  struct MklMatMulContext {
    // MKL-DNN memory.
    std::shared_ptr<mkldnn::memory> a_mem;
    std::shared_ptr<mkldnn::memory> b_mem;
    std::shared_ptr<mkldnn::memory> c_mem;

    // Descriptor and primitive-descriptor for MatMul.
    std::shared_ptr<matmul::desc> desc;
    std::shared_ptr<matmul::primitive_desc> prim_desc;

    // Memory descriptors.
    std::shared_ptr<mkldnn::memory::desc> a_md;
    std::shared_ptr<mkldnn::memory::desc> b_md;
    std::shared_ptr<mkldnn::memory::desc> c_md;

    // MatMul primitive.
    std::vector<mkldnn::primitive> matmul_primitives;
    std::vector<std::unordered_map<int, memory>> net_args;

    MklMatMulContext()
        : a_mem(nullptr),
          b_mem(nullptr),
          c_mem(nullptr),
          desc(nullptr),
          prim_desc(nullptr),
          a_md(nullptr),
          b_md(nullptr),
          c_md(nullptr) {}
  };

  void Setup(const MklMatMulParams& params) {
    std::shared_ptr<mkldnn::primitive> matmul_primitive = nullptr;

    // Create MatMul descriptor and primitive descriptor.
    context_.a_md.reset(
        new memory::desc({params.a_dims}, MklDnnType<T>(), params.a_strides));

    context_.b_md.reset(
        new memory::desc({params.b_dims}, MklDnnType<T>(), params.b_strides));

    context_.c_md.reset(
        new memory::desc({params.c_dims}, MklDnnType<T>(), params.c_strides));

    // Create matmul.
    context_.desc.reset(
        new matmul::desc(*context_.a_md, *context_.b_md, *context_.c_md));
    context_.prim_desc.reset(
        new matmul::primitive_desc(*context_.desc, cpu_engine_));

    // Create memory primitive based on dummy data.
    context_.a_mem.reset(
        new mkldnn::memory(*context_.a_md, cpu_engine_, DummyData));
    context_.b_mem.reset(
        new mkldnn::memory(*context_.b_md, cpu_engine_, DummyData));
    context_.c_mem.reset(
        new mkldnn::memory(*context_.c_md, cpu_engine_, DummyData));

    // Create matmul primitive.
    matmul_primitive.reset(new mkldnn::matmul(*context_.prim_desc));
    context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.a_mem},
                                 {MKLDNN_ARG_WEIGHTS, *context_.b_mem},
                                 {MKLDNN_ARG_DST, *context_.c_mem}});

    context_.matmul_primitives.push_back(*matmul_primitive);
    return;
  }

  struct MklMatMulContext context_;
};

template <typename T>
class MklMatMulPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklMatMulPrimitive<T>* Get(const MklMatMulParams& params,
                                    bool do_not_cache) {
    MklMatMulPrimitive<T>* matmul_prim = nullptr;

    if (do_not_cache) {
      // Always create new primitive
      matmul_prim = new MklMatMulPrimitive<T>(params);
    } else {
      // Try to find a suitable one in pool
      matmul_prim = dynamic_cast<MklMatMulPrimitive<T>*>(
          MklMatMulPrimitiveFactory<T>::GetInstance().GetMklMatMul(params));
      if (matmul_prim == nullptr) {
        matmul_prim = new MklMatMulPrimitive<T>(params);
        MklMatMulPrimitiveFactory<T>::GetInstance().SetMklMatMul(params,
                                                                 matmul_prim);
      }
    }

    return matmul_prim;
  }

 private:
  MklMatMulPrimitiveFactory() {}
  ~MklMatMulPrimitiveFactory() {}

  static MklMatMulPrimitiveFactory& GetInstance() {
    static MklMatMulPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklMatMulParams& params) {
    string prefix = "matmul_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(params.a_dims);
    key_creator.AddAsKey(params.b_dims);
    key_creator.AddAsKey(params.c_dims);
    key_creator.AddAsKey(params.a_strides);
    key_creator.AddAsKey(params.b_strides);
    key_creator.AddAsKey(params.c_strides);
    key_creator.AddAsKey(typeid(T).name());

    return key_creator.GetKey();
  }

  MklPrimitive* GetMklMatMul(const MklMatMulParams& params) {
    string key = CreateKey(params);
    return this->GetOp(key);
  }

  void SetMklMatMul(const MklMatMulParams& params, MklPrimitive* op) {
    string key = CreateKey(params);
    this->SetOp(key, op);
  }
};
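// The dnnl_gemm() helper below wraps the factory and primitive above behind a
// GEMM-style interface. Illustrative call (placeholder sizes; row-major,
// untransposed operands, so lda == k, ldb == n, ldc == n; alpha must be 1.0f
// and beta 0.0f per the DCHECKs inside):
//
//   dnnl_gemm<float>('N', 'N', m, n, k, /*alpha=*/1.0f, a, /*lda=*/k,
//                    b, /*ldb=*/n, /*beta=*/0.0f, c, /*ldc=*/n, ctx);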
template <typename T>
void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
               float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
               float beta, T* c, int64_t ldc, OpKernelContext* ctx = nullptr) {
  using dims = mkldnn::memory::dims;

  // Prepare strides based on the transa and transb flags: transposed
  // matrices have strides swapped.
  dims a_dims = dims{m, k};
  dims b_dims = dims{k, n};
  dims c_dims = dims{m, n};
  dims a_strides = tolower(transa) == 'n' ? dims{lda, 1} : dims{1, lda};
  dims b_strides = tolower(transb) == 'n' ? dims{ldb, 1} : dims{1, ldb};
  dims c_strides = dims{ldc, 1};

  // MklMatMul uses constant alpha and beta; the checks below guarantee that
  // they are never anything else.
  DCHECK_EQ(alpha, 1.0f);
  DCHECK_EQ(beta, 0.f);

  MklMatMulParams params(a_dims, b_dims, c_dims, a_strides, b_strides,
                         c_strides);
  MklMatMulPrimitive<T>* matmul_prim =
      MklMatMulPrimitiveFactory<T>::Get(params, 0);

  // Execute matmul primitive.
  std::shared_ptr<stream> cpu_stream;
  cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
  matmul_prim->Execute(a, b, c, cpu_stream);
}

}  // anonymous namespace

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_