/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_

#ifdef INTEL_MKL
#include <memory>
#include <string>
#include <vector>

#include "mkldnn.hpp"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::inner_product_forward;
using mkldnn::primitive_attr;
using mkldnn::prop_kind;
using mkldnn::stream;

namespace tensorflow {

#define L1_SIZE 32 * 1024
typedef Eigen::ThreadPoolDevice CPUDevice;

inline bool ExecuteSingleThreadedGemm(int m, int n, int k) {
  // Ideally we would determine blocking and then derive a heuristic, but the
  // targets here are very small models whose total size fits within a few L1
  // caches. So this simple working-set calculation decides whether the matrix
  // multiplication should run on a single thread.
  constexpr int kHeuristicMultiplier = 8;
  return ((sizeof(float) * (m * n + k * (m + n))) <
          L1_SIZE * kHeuristicMultiplier);
}
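// Illustrative arithmetic for the heuristic above (hypothetical shapes, not
// taken from any kernel): for m = 32, n = 64, k = 128 the working set is
// sizeof(float) * (32 * 64 + 128 * (32 + 64)) = 4 * 14336 = 57344 bytes,
// which is below L1_SIZE * kHeuristicMultiplier = 32768 * 8 = 262144 bytes,
// so ExecuteSingleThreadedGemm(32, 64, 128) returns true and the GEMM runs
// single-threaded.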

// This structure aggregates multiple inputs to MklDnnMatMul* methods.
struct MklDnnMatMulFwdParams {
  memory::dims src_dims;
  memory::dims weight_dims;
  memory::dims bias_dims;
  memory::dims dst_dims;
  memory::format_tag src_format;
  memory::format_tag weight_format;
  memory::format_tag dst_format;
  string dtypes = string("");
  struct PostOpParam {
    string name;
    std::vector<float> param;
  };
  std::vector<PostOpParam> post_op_params;

  MklDnnMatMulFwdParams(
      memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims,
      memory::dims dst_dims,
      memory::format_tag src_format = memory::format_tag::any,
      memory::format_tag weight_format = memory::format_tag::any,
      memory::format_tag dst_format = memory::format_tag::any)
      : src_dims(src_dims),
        weight_dims(weight_dims),
        bias_dims(bias_dims),
        dst_dims(dst_dims),
        src_format(src_format),
        weight_format(weight_format),
        dst_format(dst_format) {}
};
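// Illustrative sketch of filling this struct (hypothetical shapes; real
// kernels derive the dims from their input tensors): a [batch, channels]
// source against a [units, channels] weight, fused with a ReLU post-op whose
// three params follow the {scale, alpha, beta} convention expected by
// MklDnnMatMulFwdPrimitive::Setup() below.
//
//   MklDnnMatMulFwdParams fwd_params({8, 128} /* src */, {64, 128} /* weight */,
//                                    {64} /* bias */, {8, 64} /* dst */);
//   fwd_params.dtypes.append(typeid(float).name());
//   fwd_params.post_op_params.push_back({"relu", {1.0f, 0.0f, 0.0f}});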

// With quantization, input, weight, bias, and output can have different types.
// So we use different template parameters for each type.
// TODO(intel-tf): The template type "T" is currently used to match the
// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h).
// In the future, with the removal of "T" from MklPrimitiveFactory, this class
// needs to drop "T".
template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 public:
  explicit MklDnnMatMulFwdPrimitive(
      const MklDnnMatMulFwdParams& matmulFwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create matmul primitive
    if (context_.matmul_fwd == nullptr) {
      Setup(matmulFwdParams);
    }
  }

  ~MklDnnMatMulFwdPrimitive() {}

  // Inner-product forward execute with bias:
  //  - src_data: input data buffer of src
  //  - weight_data: input data buffer of weight
  //  - bias_data: input data buffer of bias
  //  - dst_data: output data buffer of dst
  void Execute(const Tinput* src_data, const Tweight* weight_data,
               const Tbias* bias_data, Toutput* dst_data,
               std::shared_ptr<stream> fwd_stream) {
#ifndef ENABLE_ONEDNN_OPENMP
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)), *fwd_stream);
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
                                      *fwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)));
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)));
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#endif  // !ENABLE_ONEDNN_OPENMP

    execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);

    // After execution, set data handle back
    context_.src_mem->set_data_handle(DummyData);
    context_.weight_mem->set_data_handle(DummyData);
    context_.bias_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>
  GetPrimitiveDesc() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for inner-product Fwd op
  struct MklDnnMatMulFwdContext {
    // MKL-DNN memory.
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> weight_mem;
    std::shared_ptr<mkldnn::memory> bias_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;

    // Descriptor and primitive-descriptor for forward inner-product.
    std::shared_ptr<mkldnn::inner_product_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> fwd_pd;

    // Memory descriptors.
    std::shared_ptr<mkldnn::memory::desc> src_md;
    std::shared_ptr<mkldnn::memory::desc> weight_md;
    std::shared_ptr<mkldnn::memory::desc> bias_md;
    std::shared_ptr<mkldnn::memory::desc> dst_md;

    // Inner-product primitive.
    std::shared_ptr<mkldnn::primitive> matmul_fwd;
    std::vector<mkldnn::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    MklDnnMatMulFwdContext()
        : src_mem(nullptr),
          weight_mem(nullptr),
          bias_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          weight_md(nullptr),
          bias_md(nullptr),
          dst_md(nullptr),
          matmul_fwd(nullptr) {}
  };

  void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) {
    // Create memory descriptors for inner-product data without specified
    // format.
    context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
                                           MklDnnType<Tinput>(),
                                           matmul_fwd_params.src_format));

    context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
                                              MklDnnType<Tweight>(),
                                              matmul_fwd_params.weight_format));

    context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
                                           MklDnnType<Toutput>(),
                                           matmul_fwd_params.dst_format));

    context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
                                            MklDnnType<Tbias>(),
                                            memory::format_tag::any));
    // Create an inner-product.
    context_.fwd_desc.reset(new inner_product_forward::desc(
        prop_kind::forward_inference, *context_.src_md, *context_.weight_md,
        *context_.bias_md, *context_.dst_md));
    context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
        *context_.fwd_desc, cpu_engine_));

    // Check if there is any fusion as post-ops
    auto const& post_op_params = matmul_fwd_params.post_op_params;
    mkldnn::primitive_attr post_ops_attr;
    mkldnn::post_ops post_ops;
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "relu" || post_op_param.name == "leakyrelu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "relu6") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale,
                                  mkldnn::algorithm::eltwise_bounded_relu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "elu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_elu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "tanh") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_tanh,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "logistic") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_logistic,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "output_scale") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          std::vector<float> scales;
          scales.push_back(post_op_param.param[0]);
          post_ops_attr.set_output_scales(0, scales);
        } else if (post_op_param.name == "sum") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          float op_scale = post_op_param.param[0];
          post_ops.append_sum(op_scale);

        } else {
          DCHECK((post_op_param.name == "relu") ||
                 (post_op_param.name == "relu6") ||
                 (post_op_param.name == "elu") ||
                 (post_op_param.name == "tanh") ||
                 (post_op_param.name == "logistic") ||
                 (post_op_param.name == "sum") ||
                 (post_op_param.name == "leakyrelu") ||
                 (post_op_param.name == "output_scale"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, post_ops_attr, cpu_engine_));
    } else {
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, cpu_engine_));
    }

    // Create memory primitive based on dummy data
    context_.src_mem.reset(
        new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
    context_.weight_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
                                         cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));
    context_.bias_mem.reset(new memory({{matmul_fwd_params.bias_dims},
                                        MklDnnType<Tbias>(),
                                        memory::format_tag::x},
                                       cpu_engine_, DummyData));

    // Create inner-product primitive.
    context_.matmul_fwd.reset(new inner_product_forward(*context_.fwd_pd));
    context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
                                 {MKLDNN_ARG_WEIGHTS, *context_.weight_mem},
                                 {MKLDNN_ARG_BIAS, *context_.bias_mem},
                                 {MKLDNN_ARG_DST, *context_.dst_mem}});

    context_.fwd_primitives.push_back(*context_.matmul_fwd);
    return;
  }

  struct MklDnnMatMulFwdContext context_;
};

template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* Get(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) {
    MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* matmul_fwd =
        nullptr;

    if (do_not_cache) {
      // Always create new primitive
      matmul_fwd =
          new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
              mkldnn_matmul_fwd_dims);
    } else {
      // Try to find a suitable one in pool
      matmul_fwd = dynamic_cast<
          MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>*>(
          MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                          Toutput>::GetInstance()
              .GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims));
      if (matmul_fwd == nullptr) {
        matmul_fwd =
            new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
                mkldnn_matmul_fwd_dims);
        MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                        Toutput>::GetInstance()
            .SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd);
      }
    }
    return matmul_fwd;
  }

 private:
  MklDnnMatMulFwdPrimitiveFactory() {}
  ~MklDnnMatMulFwdPrimitiveFactory() {}

  static MklDnnMatMulFwdPrimitiveFactory& GetInstance() {
    static MklDnnMatMulFwdPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string prefix = "matmul_fwd_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_format);

    // Generate keys for post-ops
    for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
      if (post_op_param.name == "relu" || post_op_param.name == "relu6" ||
          post_op_param.name == "elu" || post_op_param.name == "tanh" ||
          post_op_param.name == "logistic" ||
          post_op_param.name == "leakyrelu") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
        key_creator.AddAsKey(post_op_param.param[1]);
        key_creator.AddAsKey(post_op_param.param[2]);
      } else if (post_op_param.name == "sum") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else if (post_op_param.name == "output_scale") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else {
        return string("not_a_key");
      }
    }
    return key_creator.GetKey();
  }

  MklPrimitive* GetMklDnnMatMulFwd(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    return this->GetOp(key);
  }

  void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims,
                          MklPrimitive* op) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    this->SetOp(key, op);
  }
};
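// Illustrative sketch of how a kernel typically uses the factory above
// (hypothetical locals: fwd_params is an MklDnnMatMulFwdParams, eigen_tp is an
// MklDnnThreadPool built from the OpKernelContext as in dnnl_gemm() below, and
// the *_data pointers are buffers owned by the calling op):
//
//   MklDnnMatMulFwdPrimitive<float, float, float, float, float>* matmul_fwd =
//       MklDnnMatMulFwdPrimitiveFactory<float, float, float, float,
//                                       float>::Get(fwd_params,
//                                                   /*do_not_cache=*/false);
//   std::shared_ptr<stream> cpu_stream(
//       CreateStream(&eigen_tp, matmul_fwd->GetEngine()));
//   matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data,
//                       cpu_stream);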

template <class Tweight, class Toutput>
class MklDnnMatMulOpBase : public OpKernel {
 public:
  explicit MklDnnMatMulOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}
  void Compute(OpKernelContext* context) override = 0;

  // Allocate output tensor.
  virtual void AllocateOutputTensor(
      OpKernelContext* context,
      const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc,
      const memory::dims& output_dims_mkl_order,
      MklTensorFormat output_tf_format, Tensor** output_tensor,
      bool native_format = false) {
    DCHECK(output_tensor);
    auto dst_pd = mkldnn_matmul_prim_desc.dst_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<Toutput>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));

    if (native_format) {
      output_tf_shape = output_mkl_shape.GetTfShape();
    }
    // Allocate Output Tensor
    AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor,
                              output_tf_shape, output_mkl_shape, native_format);
  }

  // TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
  // be acquired before entering the function, since it is acquired
  // inside the function.
  inline bool IsWeightCacheEmpty(OpKernelContext* context)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    return (weight_oi_.NumElements() == 0);
  }

  // Cache the converted weight in a tensor.
  // Only one thread can execute this method at any given time.
  void CacheWeight(
      OpKernelContext* context,
      const std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>&
          matmul_fwd_pd,
      Tweight* weight_data, const Tensor& weight_tensor,
      MklDnnData<Tweight>& weight, const memory::desc& weight_md)
      TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    const Tensor& weight_t = weight_oi_;

    // If the weights are already cached, there's nothing to do
    if (weight_t.NumElements() > 0) {
      return;
    }

    // reorder and cache the weight
    weight.SetUsrMem(weight_md, &weight_tensor);
    weight.CheckReorderToOpMem(matmul_fwd_pd.get()->weights_desc(), cpu_engine_,
                               context);
    weight_data = static_cast<Tweight*>(weight.GetOpMem().get_data_handle());

    size_t weight_size = matmul_fwd_pd.get()->weights_desc().get_size();
    TensorShape weight_tf_shape;
    weight_tf_shape.AddDim(weight_size / sizeof(Tweight));

    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<Tweight>::value,
                                          weight_tf_shape, &weight_oi_));

    void* weight_oi_t_data = weight.GetTensorBuffer(&weight_oi_);
    memcpy(weight_oi_t_data, weight_data, weight_size);

    // cache the memory descriptor
    auto expected_md = matmul_fwd_pd->weights_desc();
    TensorShape weight_mkl_format;
    weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight));

    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<Tweight>::value,
                                          weight_mkl_format, &weight_oi_md_));
    *reinterpret_cast<memory::desc*>(weight_oi_md_.flat<Tweight>().data()) =
        expected_md;
  }

  Tweight* GetCachedWeight(OpKernelContext* context,
                           const memory::desc& expected_md)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    const Tensor& weight_t = weight_oi_;
    const Tensor& weight_md_t = weight_oi_md_;

    // Check if the memory descriptor of the cached weight is the same as
    // expected_md. If so, return the cached memory; otherwise return nullptr.
    if (weight_md_t.flat<Tweight>().size()) {
      const memory::desc& stored_md =
          *(static_cast<memory::desc*>(weight_md_t.data()));
      if (stored_md == expected_md) {
        return static_cast<Tweight*>(
            const_cast<Tweight*>(weight_t.flat<Tweight>().data()));
      }
    }
    return nullptr;
  }
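  // Illustrative sketch of the intended caching flow in a derived op's
  // Compute() (hypothetical locals: matmul_fwd_pd, weight_data, weight_tensor,
  // weight and weight_md come from the caller; only the cache methods below
  // belong to this class):
  //
  //   if (is_weight_const_) {
  //     if (this->IsWeightCacheEmpty(context)) {
  //       this->CacheWeight(context, matmul_fwd_pd, weight_data, weight_tensor,
  //                         weight, weight_md);
  //     }
  //     Tweight* cached_weight = this->GetCachedWeight(
  //         context, matmul_fwd_pd.get()->weights_desc());
  //     if (cached_weight != nullptr) weight_data = cached_weight;
  //   }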

  engine cpu_engine_ = engine(engine::kind::cpu, 0);

 protected:
  // Mutex and tensors used to cache the reordered weight and its memory
  // descriptor.
  mutex mu_;
  Tensor weight_oi_ TF_GUARDED_BY(mu_);
  Tensor weight_oi_md_ TF_GUARDED_BY(mu_);

  bool is_weight_const_;

  const int kInputIndexSrc = 0;
  const int kInputIndexWeight = 1;
  const int kInputIndexBias = 2;
  const int kOutputIndexDst = 0;
};

using mkldnn::matmul;

namespace {

struct MklMatMulParams {
  memory::dims a_dims;
  memory::dims b_dims;
  memory::dims c_dims;
  memory::dims a_strides;
  memory::dims b_strides;
  memory::dims c_strides;

  MklMatMulParams(memory::dims a_dims, memory::dims b_dims, memory::dims c_dims,
                  memory::dims a_strides, memory::dims b_strides,
                  memory::dims c_strides)
      : a_dims(a_dims),
        b_dims(b_dims),
        c_dims(c_dims),
        a_strides(a_strides),
        b_strides(b_strides),
        c_strides(c_strides) {}
};

template <typename T>
class MklMatMulPrimitive : public MklPrimitive {
 public:
  explicit MklMatMulPrimitive(const MklMatMulParams& params)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create matmul primitive
    Setup(params);
  }

  ~MklMatMulPrimitive() {}

  void Execute(const T* a_data, const T* b_data, T* c_data,
               std::shared_ptr<stream> stream) {
#ifndef ENABLE_ONEDNN_OPENMP
    context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)),
                                    *stream);
    context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)),
                                    *stream);
    context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)),
                                    *stream);
#else
    context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
    context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
    context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)));
#endif  // !ENABLE_ONEDNN_OPENMP
    execute_primitives(context_.matmul_primitives, stream, context_.net_args);

    // After execution, set data handle back
    context_.a_mem->set_data_handle(DummyData);
    context_.b_mem->set_data_handle(DummyData);
    context_.c_mem->set_data_handle(DummyData);
  }

 private:
  // Primitive reuse context for MatMul op
  struct MklMatMulContext {
    // MKL-DNN memory.
    std::shared_ptr<mkldnn::memory> a_mem;
    std::shared_ptr<mkldnn::memory> b_mem;
    std::shared_ptr<mkldnn::memory> c_mem;

    // Descriptor and primitive-descriptor for MatMul.
    std::shared_ptr<matmul::desc> desc;
    std::shared_ptr<matmul::primitive_desc> prim_desc;

    // Memory descriptors.
    std::shared_ptr<mkldnn::memory::desc> a_md;
    std::shared_ptr<mkldnn::memory::desc> b_md;
    std::shared_ptr<mkldnn::memory::desc> c_md;

    // MatMul primitive.
    std::vector<mkldnn::primitive> matmul_primitives;
    std::vector<std::unordered_map<int, memory>> net_args;

    MklMatMulContext()
        : a_mem(nullptr),
          b_mem(nullptr),
          c_mem(nullptr),
          desc(nullptr),
          prim_desc(nullptr),
          a_md(nullptr),
          b_md(nullptr),
          c_md(nullptr) {}
  };


  void Setup(const MklMatMulParams& params) {
    std::shared_ptr<mkldnn::primitive> matmul_primitive = nullptr;

    // Create MatMul descriptor and primitive descriptor.
    context_.a_md.reset(
        new memory::desc({params.a_dims}, MklDnnType<T>(), params.a_strides));

    context_.b_md.reset(
        new memory::desc({params.b_dims}, MklDnnType<T>(), params.b_strides));

    context_.c_md.reset(
        new memory::desc({params.c_dims}, MklDnnType<T>(), params.c_strides));

    // Create matmul.
    context_.desc.reset(
        new matmul::desc(*context_.a_md, *context_.b_md, *context_.c_md));
    context_.prim_desc.reset(
        new matmul::primitive_desc(*context_.desc, cpu_engine_));

    // Create memory primitive based on dummy data.
    context_.a_mem.reset(
        new mkldnn::memory(*context_.a_md, cpu_engine_, DummyData));
    context_.b_mem.reset(
        new mkldnn::memory(*context_.b_md, cpu_engine_, DummyData));
    context_.c_mem.reset(
        new mkldnn::memory(*context_.c_md, cpu_engine_, DummyData));

    // Create matmul primitive.
    matmul_primitive.reset(new mkldnn::matmul(*context_.prim_desc));
    context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.a_mem},
                                 {MKLDNN_ARG_WEIGHTS, *context_.b_mem},
                                 {MKLDNN_ARG_DST, *context_.c_mem}});

    context_.matmul_primitives.push_back(*matmul_primitive);
    return;
  }

  struct MklMatMulContext context_;
};

template <typename T>
class MklMatMulPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklMatMulPrimitive<T>* Get(const MklMatMulParams& params,
                                    bool do_not_cache) {
    MklMatMulPrimitive<T>* matmul_prim = nullptr;

    if (do_not_cache) {
      // Always create new primitive
      matmul_prim = new MklMatMulPrimitive<T>(params);
    } else {
      // Try to find a suitable one in pool
      matmul_prim = dynamic_cast<MklMatMulPrimitive<T>*>(
          MklMatMulPrimitiveFactory<T>::GetInstance().GetMklMatMul(params));
      if (matmul_prim == nullptr) {
        matmul_prim = new MklMatMulPrimitive<T>(params);
        MklMatMulPrimitiveFactory<T>::GetInstance().SetMklMatMul(params,
                                                                 matmul_prim);
      }
    }

    return matmul_prim;
  }

 private:
  MklMatMulPrimitiveFactory() {}
  ~MklMatMulPrimitiveFactory() {}

  static MklMatMulPrimitiveFactory& GetInstance() {
    static MklMatMulPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklMatMulParams& params) {
    string prefix = "matmul_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(params.a_dims);
    key_creator.AddAsKey(params.b_dims);
    key_creator.AddAsKey(params.c_dims);
    key_creator.AddAsKey(params.a_strides);
    key_creator.AddAsKey(params.b_strides);
    key_creator.AddAsKey(params.c_strides);
    key_creator.AddAsKey(typeid(T).name());

    return key_creator.GetKey();
  }

  MklPrimitive* GetMklMatMul(const MklMatMulParams& params) {
    string key = CreateKey(params);
    return this->GetOp(key);
  }

  void SetMklMatMul(const MklMatMulParams& params, MklPrimitive* op) {
    string key = CreateKey(params);
    this->SetOp(key, op);
  }
};

template <typename T>
void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
               float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
               float beta, T* c, int64_t ldc, OpKernelContext* ctx = nullptr) {
  using dims = mkldnn::memory::dims;

  // Prepare strides based on the transa and transb flags: transposed
  // matrices have strides swapped
  dims a_dims = dims{m, k};
  dims b_dims = dims{k, n};
  dims c_dims = dims{m, n};
  dims a_strides = tolower(transa) == 'n' ? dims{lda, 1} : dims{1, lda};
  dims b_strides = tolower(transb) == 'n' ? dims{ldb, 1} : dims{1, ldb};
  dims c_strides = dims{ldc, 1};

  // MklMatMul only supports fixed scaling factors, so assert here that the
  // caller passed alpha = 1.0f and beta = 0.0f.
  DCHECK_EQ(alpha, 1.0f);
  DCHECK_EQ(beta, 0.f);

  MklMatMulParams params(a_dims, b_dims, c_dims, a_strides, b_strides,
                         c_strides);
  MklMatMulPrimitive<T>* matmul_prim =
      MklMatMulPrimitiveFactory<T>::Get(params, /*do_not_cache=*/false);

  // Execute matmul primitive.
  std::shared_ptr<stream> cpu_stream;
  MklDnnThreadPool eigen_tp(ctx);
  cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine()));
  matmul_prim->Execute(a, b, c, cpu_stream);
}
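// Illustrative call of dnnl_gemm() (hypothetical buffers a_ptr/b_ptr/c_ptr):
// a plain row-major C = A * B with A of shape m x k and B of shape k x n, no
// transposition, and the required alpha = 1.0f / beta = 0.0f:
//
//   dnnl_gemm<float>('N', 'N', m, n, k, 1.0f, a_ptr, /*lda=*/k, b_ptr,
//                    /*ldb=*/n, 0.0f, c_ptr, /*ldc=*/n, ctx);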

}  // anonymous namespace

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_