/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_

#ifdef INTEL_MKL
#include <memory>
#include <string>
#include <vector>

#include "mkldnn.hpp"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::inner_product_forward;
using mkldnn::primitive_attr;
using mkldnn::prop_kind;
using mkldnn::stream;

namespace tensorflow {

#define L1_SIZE (32 * 1024)
typedef Eigen::ThreadPoolDevice CPUDevice;

inline bool ExecuteSingleThreadedGemm(int m, int n, int k) {
  // Ideally we would determine blocking and derive a heuristic from it, but
  // the targets here are very small models whose total size fits within a
  // few L1 caches. So we use this simple working-set calculation to decide
  // whether the matrix multiplication should run on a single thread.
  constexpr int kHeuristicMultiplier = 8;
  return ((sizeof(float) * (m * n + k * (m + n))) <
          L1_SIZE * kHeuristicMultiplier);
}
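
// Worked example (illustrative numbers only): for m = n = k = 128 the
// working set is sizeof(float) * (128 * 128 + 128 * (128 + 128)) =
// 4 * 49152 bytes = 192 KB, which is below the threshold
// L1_SIZE * kHeuristicMultiplier = 32 KB * 8 = 256 KB, so such a GEMM
// would run single-threaded.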

// This structure aggregates multiple inputs to MklDnnMatMul* methods.
struct MklDnnMatMulFwdParams {
  memory::dims src_dims;
  memory::dims weight_dims;
  memory::dims bias_dims;
  memory::dims dst_dims;
  memory::format_tag src_format;
  memory::format_tag weight_format;
  memory::format_tag dst_format;
  string dtypes = string("");
  struct PostOpParam {
    string name;
    std::vector<float> param;
  };
  std::vector<PostOpParam> post_op_params;

  MklDnnMatMulFwdParams(
      memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims,
      memory::dims dst_dims,
      memory::format_tag src_format = memory::format_tag::any,
      memory::format_tag weight_format = memory::format_tag::any,
      memory::format_tag dst_format = memory::format_tag::any)
      : src_dims(src_dims),
        weight_dims(weight_dims),
        bias_dims(bias_dims),
        dst_dims(dst_dims),
        src_format(src_format),
        weight_format(weight_format),
        dst_format(dst_format) {}
};
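
// A minimal usage sketch (the dimensions below are made up for
// illustration). Eltwise post-ops are encoded as three floats
// {scale, alpha, beta}; "sum" and "output_scale" take a single float.
//
//   MklDnnMatMulFwdParams fwd_params({2, 32},    // src:    batch x in
//                                    {64, 32},   // weight: out x in
//                                    {64},       // bias:   out
//                                    {2, 64});   // dst:    batch x out
//   fwd_params.post_op_params.push_back({"relu", {1.0f, 0.0f, 0.0f}});
//   fwd_params.post_op_params.push_back({"output_scale", {0.5f}});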

// With quantization, input, weight, bias, and output can have different types.
// So we use different template parameters for each type.
// TODO(intel-tf): The template type "T" is currently used to match the
// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h).
// In the future, with the removal of "T" from MklPrimitiveFactory, this class
// needs to drop "T".
template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 public:
  explicit MklDnnMatMulFwdPrimitive(
      const MklDnnMatMulFwdParams& matmulFwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create matmul primitive
    if (context_.matmul_fwd == nullptr) {
      Setup(matmulFwdParams);
    }
  }

  ~MklDnnMatMulFwdPrimitive() {}

  // Inner-product forward execute with bias:
  // - src_data: input data buffer of src
  // - weight_data: input data buffer of weight
  // - bias_data: input data buffer of bias
  // - dst_data: output data buffer of dst
  void Execute(const Tinput* src_data, const Tweight* weight_data,
               const Tbias* bias_data, Toutput* dst_data,
               std::shared_ptr<stream> fwd_stream) {
#ifndef ENABLE_ONEDNN_OPENMP
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)), *fwd_stream);
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
                                      *fwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)));
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)));
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#endif  // !ENABLE_ONEDNN_OPENMP

    execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);

    // After execution, set data handles back to DummyData.
    context_.src_mem->set_data_handle(DummyData);
    context_.weight_mem->set_data_handle(DummyData);
    context_.bias_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>
  GetPrimitiveDesc() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for inner-product Fwd op
  struct MklDnnMatMulFwdContext {
    // MKL-DNN memory.
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> weight_mem;
    std::shared_ptr<mkldnn::memory> bias_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;

    // Descriptor and primitive-descriptor for forward inner-product.
    std::shared_ptr<mkldnn::inner_product_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> fwd_pd;

    // Memory descriptors.
    std::shared_ptr<mkldnn::memory::desc> src_md;
    std::shared_ptr<mkldnn::memory::desc> weight_md;
    std::shared_ptr<mkldnn::memory::desc> bias_md;
    std::shared_ptr<mkldnn::memory::desc> dst_md;

    // Inner-product primitive.
    std::shared_ptr<mkldnn::primitive> matmul_fwd;
    std::vector<mkldnn::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    MklDnnMatMulFwdContext()
        : src_mem(nullptr),
          weight_mem(nullptr),
          bias_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          weight_md(nullptr),
          bias_md(nullptr),
          dst_md(nullptr),
          matmul_fwd(nullptr) {}
  };

  void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) {
    // Create memory descriptors for inner-product data using the requested
    // formats (memory::format_tag::any by default).
    context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
                                           MklDnnType<Tinput>(),
                                           matmul_fwd_params.src_format));

    context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
                                              MklDnnType<Tweight>(),
                                              matmul_fwd_params.weight_format));

    context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
                                           MklDnnType<Toutput>(),
                                           matmul_fwd_params.dst_format));

    context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
                                            MklDnnType<Tbias>(),
                                            memory::format_tag::any));
    // Create an inner-product descriptor; the primitive descriptor is
    // created below, once any post-ops are known.
    context_.fwd_desc.reset(new inner_product_forward::desc(
        prop_kind::forward_inference, *context_.src_md, *context_.weight_md,
        *context_.bias_md, *context_.dst_md));

    // Check for any fused post-ops.
    auto const& post_op_params = matmul_fwd_params.post_op_params;
    mkldnn::primitive_attr post_ops_attr;
    mkldnn::post_ops post_ops;
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "relu" || post_op_param.name == "leakyrelu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "relu6") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale,
                                  mkldnn::algorithm::eltwise_bounded_relu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "elu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_elu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "tanh") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_tanh,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "logistic") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_logistic,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "output_scale") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          std::vector<float> scales;
          scales.push_back(post_op_param.param[0]);
          post_ops_attr.set_output_scales(0, scales);
        } else if (post_op_param.name == "sum") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          float op_scale = post_op_param.param[0];
          post_ops.append_sum(op_scale);
        } else {
          DCHECK((post_op_param.name == "relu") ||
                 (post_op_param.name == "relu6") ||
                 (post_op_param.name == "elu") ||
                 (post_op_param.name == "tanh") ||
                 (post_op_param.name == "logistic") ||
                 (post_op_param.name == "sum") ||
                 (post_op_param.name == "leakyrelu") ||
                 (post_op_param.name == "output_scale"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, post_ops_attr, cpu_engine_));
    } else {
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, cpu_engine_));
    }

    // Create memory primitives based on dummy data.
    context_.src_mem.reset(
        new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
    context_.weight_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
                                         cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));
    context_.bias_mem.reset(new memory({{matmul_fwd_params.bias_dims},
                                        MklDnnType<Tbias>(),
                                        memory::format_tag::x},
                                       cpu_engine_, DummyData));

    // Create the inner-product primitive.
    context_.matmul_fwd.reset(new inner_product_forward(*context_.fwd_pd));
    context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
                                 {MKLDNN_ARG_WEIGHTS, *context_.weight_mem},
                                 {MKLDNN_ARG_BIAS, *context_.bias_mem},
                                 {MKLDNN_ARG_DST, *context_.dst_mem}});

    context_.fwd_primitives.push_back(*context_.matmul_fwd);
  }

  struct MklDnnMatMulFwdContext context_;
};

template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* Get(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) {
    MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* matmul_fwd =
        nullptr;

    if (do_not_cache) {
      // Always create a new primitive.
      matmul_fwd =
          new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
              mkldnn_matmul_fwd_dims);
    } else {
      // Try to find a suitable primitive in the pool.
      matmul_fwd = dynamic_cast<
          MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>*>(
          MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                          Toutput>::GetInstance()
              .GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims));
      if (matmul_fwd == nullptr) {
        matmul_fwd =
            new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
                mkldnn_matmul_fwd_dims);
        MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                        Toutput>::GetInstance()
            .SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd);
      }
    }
    return matmul_fwd;
  }
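
  // A hedged usage sketch (all-float instantiation; fwd_params, the data
  // pointers, and cpu_stream are assumed to be set up by the caller):
  //
  //   auto* matmul_fwd =
  //       MklDnnMatMulFwdPrimitiveFactory<float, float, float, float,
  //                                       float>::Get(fwd_params,
  //                                                   /*do_not_cache=*/false);
  //   matmul_fwd->Execute(src_ptr, weight_ptr, bias_ptr, dst_ptr, cpu_stream);
  //
  // Primitives fetched with do_not_cache == false live in the factory pool;
  // with do_not_cache == true the caller is typically expected to delete the
  // primitive after use.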

 private:
  MklDnnMatMulFwdPrimitiveFactory() {}
  ~MklDnnMatMulFwdPrimitiveFactory() {}

  static MklDnnMatMulFwdPrimitiveFactory& GetInstance() {
    static MklDnnMatMulFwdPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string prefix = "matmul_fwd_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_format);

    // Generate keys for post-ops.
    for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
      if (post_op_param.name == "relu" || post_op_param.name == "relu6" ||
          post_op_param.name == "elu" || post_op_param.name == "tanh" ||
          post_op_param.name == "logistic" ||
          post_op_param.name == "leakyrelu") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
        key_creator.AddAsKey(post_op_param.param[1]);
        key_creator.AddAsKey(post_op_param.param[2]);
      } else if (post_op_param.name == "sum") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else if (post_op_param.name == "output_scale") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else {
        return string("not_a_key");
      }
    }
    return key_creator.GetKey();
  }

  MklPrimitive* GetMklDnnMatMulFwd(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    return this->GetOp(key);
  }

  void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims,
                          MklPrimitive* op) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    this->SetOp(key, op);
  }
};

template <class Tweight, class Toutput>
class MklDnnMatMulOpBase : public OpKernel {
 public:
  explicit MklDnnMatMulOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}
  void Compute(OpKernelContext* context) override = 0;

  // Allocate the output tensor.
  virtual void AllocateOutputTensor(
      OpKernelContext* context,
      const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc,
      const memory::dims& output_dims_mkl_order,
      MklTensorFormat output_tf_format, Tensor** output_tensor,
      bool native_format = false) {
    DCHECK(output_tensor);
    auto dst_pd = mkldnn_matmul_prim_desc.dst_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<Toutput>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim(dst_pd.get_size() / sizeof(Toutput));

    if (native_format) {
      output_tf_shape = output_mkl_shape.GetTfShape();
    }
    AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor,
                              output_tf_shape, output_mkl_shape,
                              native_format);
  }

  // The TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
  // be acquired before entering the function, since it is acquired
  // inside the function.
  inline bool IsWeightCacheEmpty(OpKernelContext* context)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    return (weight_oi_.NumElements() == 0);
  }

  // Cache the converted weight in a tensor.
  // Only one thread can execute this method at any given time.
  void CacheWeight(
      OpKernelContext* context,
      const std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>&
          matmul_fwd_pd,
      Tweight* weight_data, const Tensor& weight_tensor,
      MklDnnData<Tweight>& weight, const memory::desc& weight_md)
      TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    const Tensor& weight_t = weight_oi_;

    // If the weights are already cached, there is nothing to do.
    if (weight_t.NumElements() > 0) {
      return;
    }

    // Reorder and cache the weight.
    weight.SetUsrMem(weight_md, &weight_tensor);
    weight.CheckReorderToOpMem(matmul_fwd_pd.get()->weights_desc(), cpu_engine_,
                               context);
    weight_data = static_cast<Tweight*>(weight.GetOpMem().get_data_handle());

    size_t weight_size = matmul_fwd_pd.get()->weights_desc().get_size();
    TensorShape weight_tf_shape;
    weight_tf_shape.AddDim(weight_size / sizeof(Tweight));

    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<Tweight>::value,
                                          weight_tf_shape, &weight_oi_));

    void* weight_oi_t_data = weight.GetTensorBuffer(&weight_oi_);
    memcpy(weight_oi_t_data, weight_data, weight_size);

    // Cache the memory descriptor of the reordered weight.
    auto expected_md = matmul_fwd_pd->weights_desc();
    TensorShape weight_mkl_format;
    weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight));

    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<Tweight>::value,
                                          weight_mkl_format, &weight_oi_md_));
    *reinterpret_cast<memory::desc*>(weight_oi_md_.flat<Tweight>().data()) =
        expected_md;
  }

  Tweight* GetCachedWeight(OpKernelContext* context,
                           const memory::desc& expected_md)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    const Tensor& weight_t = weight_oi_;
    const Tensor& weight_md_t = weight_oi_md_;

    // Check whether the memory descriptor of the cached weight matches
    // expected_md; if so, return the cached memory, else return nullptr.
    if (weight_md_t.flat<Tweight>().size()) {
      const memory::desc& stored_md =
          *(static_cast<memory::desc*>(weight_md_t.data()));
      if (stored_md == expected_md) {
        return const_cast<Tweight*>(weight_t.flat<Tweight>().data());
      }
    }
    return nullptr;
  }
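
  // Sketch of how a derived kernel might drive the weight cache (variable
  // names are illustrative; matmul_fwd_pd, weight_data, weight_tensor,
  // weight, weight_md, and expected_md come from the kernel's Compute()):
  //
  //   if (this->is_weight_const_) {
  //     if (this->IsWeightCacheEmpty(context)) {
  //       this->CacheWeight(context, matmul_fwd_pd, weight_data,
  //                         weight_tensor, weight, weight_md);
  //     }
  //     weight_data = this->GetCachedWeight(context, expected_md);
  //   }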

  engine cpu_engine_ = engine(engine::kind::cpu, 0);

 protected:
  // Mutex guarding the tensors below, which cache the reordered weight.
  mutex mu_;
  Tensor weight_oi_ TF_GUARDED_BY(mu_);
  Tensor weight_oi_md_ TF_GUARDED_BY(mu_);

  bool is_weight_const_;

  const int kInputIndexSrc = 0;
  const int kInputIndexWeight = 1;
  const int kInputIndexBias = 2;
  const int kOutputIndexDst = 0;
};

using mkldnn::matmul;

namespace {

struct MklMatMulParams {
  memory::dims a_dims;
  memory::dims b_dims;
  memory::dims c_dims;
  memory::dims a_strides;
  memory::dims b_strides;
  memory::dims c_strides;

  MklMatMulParams(memory::dims a_dims, memory::dims b_dims, memory::dims c_dims,
                  memory::dims a_strides, memory::dims b_strides,
                  memory::dims c_strides)
      : a_dims(a_dims),
        b_dims(b_dims),
        c_dims(c_dims),
        a_strides(a_strides),
        b_strides(b_strides),
        c_strides(c_strides) {}
};
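
// Illustrative example: for a row-major batched multiply C[b] = A[b] * B[b]
// with A of shape (batch, m, k), B of shape (batch, k, n), and C of shape
// (batch, m, n), the parameters would be built as follows (all names here
// are placeholders):
//
//   MklMatMulParams params({batch, m, k}, {batch, k, n}, {batch, m, n},
//                          {m * k, k, 1},   // a_strides, row-major
//                          {k * n, n, 1},   // b_strides, row-major
//                          {m * n, n, 1});  // c_strides, row-major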

template <typename T>
class MklMatMulPrimitive : public MklPrimitive {
 public:
  explicit MklMatMulPrimitive(const MklMatMulParams& params)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create matmul primitive
    Setup(params);
  }

  ~MklMatMulPrimitive() {}

  void Execute(const T* a_data, const T* b_data, T* c_data,
               std::shared_ptr<stream> stream) {
#ifndef ENABLE_ONEDNN_OPENMP
    context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)),
                                    *stream);
    context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)),
                                    *stream);
    context_.c_mem->set_data_handle(static_cast<void*>(c_data), *stream);
#else
    context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
    context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
    context_.c_mem->set_data_handle(static_cast<void*>(c_data));
#endif  // !ENABLE_ONEDNN_OPENMP
    execute_primitives(context_.matmul_primitives, stream, context_.net_args);

    // After execution, set data handles back to DummyData.
    context_.a_mem->set_data_handle(DummyData);
    context_.b_mem->set_data_handle(DummyData);
    context_.c_mem->set_data_handle(DummyData);
  }

 private:
  // Primitive reuse context for MatMul op
  struct MklMatMulContext {
    // MKL-DNN memory.
    std::shared_ptr<mkldnn::memory> a_mem;
    std::shared_ptr<mkldnn::memory> b_mem;
    std::shared_ptr<mkldnn::memory> c_mem;

    // Descriptor and primitive-descriptor for MatMul.
    std::shared_ptr<matmul::desc> desc;
    std::shared_ptr<matmul::primitive_desc> prim_desc;

    // Memory descriptors.
    std::shared_ptr<mkldnn::memory::desc> a_md;
    std::shared_ptr<mkldnn::memory::desc> b_md;
    std::shared_ptr<mkldnn::memory::desc> c_md;

    // MatMul primitive.
    std::vector<mkldnn::primitive> matmul_primitives;
    std::vector<std::unordered_map<int, memory>> net_args;

    MklMatMulContext()
        : a_mem(nullptr),
          b_mem(nullptr),
          c_mem(nullptr),
          desc(nullptr),
          prim_desc(nullptr),
          a_md(nullptr),
          b_md(nullptr),
          c_md(nullptr) {}
  };

  void Setup(const MklMatMulParams& params) {
    std::shared_ptr<mkldnn::primitive> matmul_primitive = nullptr;

    // Create MatMul descriptor and primitive descriptor.
    context_.a_md.reset(
        new memory::desc({params.a_dims}, MklDnnType<T>(), params.a_strides));

    context_.b_md.reset(
        new memory::desc({params.b_dims}, MklDnnType<T>(), params.b_strides));

    context_.c_md.reset(
        new memory::desc({params.c_dims}, MklDnnType<T>(), params.c_strides));

    // Create matmul.
    context_.desc.reset(
        new matmul::desc(*context_.a_md, *context_.b_md, *context_.c_md));
    context_.prim_desc.reset(
        new matmul::primitive_desc(*context_.desc, cpu_engine_));

    // Create memory primitives based on dummy data.
    context_.a_mem.reset(
        new mkldnn::memory(*context_.a_md, cpu_engine_, DummyData));
    context_.b_mem.reset(
        new mkldnn::memory(*context_.b_md, cpu_engine_, DummyData));
    context_.c_mem.reset(
        new mkldnn::memory(*context_.c_md, cpu_engine_, DummyData));

    // Create matmul primitive.
    matmul_primitive.reset(new mkldnn::matmul(*context_.prim_desc));
    context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.a_mem},
                                 {MKLDNN_ARG_WEIGHTS, *context_.b_mem},
                                 {MKLDNN_ARG_DST, *context_.c_mem}});

    context_.matmul_primitives.push_back(*matmul_primitive);
  }

  struct MklMatMulContext context_;
};

template <typename T>
class MklMatMulPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklMatMulPrimitive<T>* Get(const MklMatMulParams& params,
                                    bool do_not_cache) {
    MklMatMulPrimitive<T>* matmul_prim = nullptr;

    if (do_not_cache) {
      // Always create a new primitive.
      matmul_prim = new MklMatMulPrimitive<T>(params);
    } else {
      // Try to find a suitable primitive in the pool.
      matmul_prim = dynamic_cast<MklMatMulPrimitive<T>*>(
          MklMatMulPrimitiveFactory<T>::GetInstance().GetMklMatMul(params));
      if (matmul_prim == nullptr) {
        matmul_prim = new MklMatMulPrimitive<T>(params);
        MklMatMulPrimitiveFactory<T>::GetInstance().SetMklMatMul(params,
                                                                 matmul_prim);
      }
    }

    return matmul_prim;
  }

 private:
  MklMatMulPrimitiveFactory() {}
  ~MklMatMulPrimitiveFactory() {}

  static MklMatMulPrimitiveFactory& GetInstance() {
    static MklMatMulPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklMatMulParams& params) {
    string prefix = "matmul_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(params.a_dims);
    key_creator.AddAsKey(params.b_dims);
    key_creator.AddAsKey(params.c_dims);
    key_creator.AddAsKey(params.a_strides);
    key_creator.AddAsKey(params.b_strides);
    key_creator.AddAsKey(params.c_strides);
    key_creator.AddAsKey(typeid(T).name());

    return key_creator.GetKey();
  }

  MklPrimitive* GetMklMatMul(const MklMatMulParams& params) {
    string key = CreateKey(params);
    return this->GetOp(key);
  }

  void SetMklMatMul(const MklMatMulParams& params, MklPrimitive* op) {
    string key = CreateKey(params);
    this->SetOp(key, op);
  }
};

template <typename T>
void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
               float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
               float beta, T* c, int64_t ldc, OpKernelContext* ctx = nullptr) {
  using dims = mkldnn::memory::dims;

  // Prepare strides based on the transa and transb flags: transposed
  // matrices have their strides swapped.
  dims a_dims = dims{m, k};
  dims b_dims = dims{k, n};
  dims c_dims = dims{m, n};
  dims a_strides = tolower(transa) == 'n' ? dims{lda, 1} : dims{1, lda};
  dims b_strides = tolower(transb) == 'n' ? dims{ldb, 1} : dims{1, ldb};
  dims c_strides = dims{ldc, 1};

  // MklMatMul only supports alpha = 1.0f and beta = 0.0f, so guarantee
  // here that they are never anything else.
  DCHECK_EQ(alpha, 1.0f);
  DCHECK_EQ(beta, 0.0f);

  MklMatMulParams params(a_dims, b_dims, c_dims, a_strides, b_strides,
                         c_strides);
  MklMatMulPrimitive<T>* matmul_prim =
      MklMatMulPrimitiveFactory<T>::Get(params, /*do_not_cache=*/false);

  // Execute matmul primitive.
  std::shared_ptr<stream> cpu_stream;
  MklDnnThreadPool eigen_tp(ctx);
  cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine()));
  matmul_prim->Execute(a, b, c, cpu_stream);
}
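
// A minimal usage sketch for dnnl_gemm (row-major, no transposition; the
// buffers are illustrative). alpha must be 1.0f and beta 0.0f per the
// DCHECKs above, and lda/ldb/ldc are the row strides k, n, and n:
//
//   const int64_t m = 2, n = 3, k = 4;
//   std::vector<float> a(m * k, 1.0f), b(k * n, 1.0f), c(m * n, 0.0f);
//   dnnl_gemm<float>('n', 'n', m, n, k, 1.0f, a.data(), k, b.data(), n,
//                    0.0f, c.data(), n);  // c = a * b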

}  // anonymous namespace

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_