• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/nn_ops.cc.
17 #ifdef INTEL_MKL
18 
19 #include "tensorflow/core/kernels/mkl/mkl_conv_ops.h"
20 
#include <algorithm>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
26 
27 #include "mkldnn.hpp"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/core/framework/bounds_check.h"
30 #include "tensorflow/core/framework/numeric_op.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/register_types.h"
33 #include "tensorflow/core/framework/tensor.h"
34 #include "tensorflow/core/framework/tensor_shape.h"
35 #include "tensorflow/core/framework/tensor_slice.h"
36 #include "tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h"
37 #include "tensorflow/core/kernels/no_op.h"
38 #include "tensorflow/core/kernels/ops_util.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/gtl/array_slice.h"
41 #include "tensorflow/core/lib/strings/numbers.h"
42 #include "tensorflow/core/lib/strings/str_util.h"
43 #include "tensorflow/core/lib/strings/strcat.h"
44 #include "tensorflow/core/platform/logging.h"
45 #include "tensorflow/core/platform/macros.h"
46 #include "tensorflow/core/util/mkl_util.h"
47 #include "tensorflow/core/util/padding.h"
48 #include "tensorflow/core/util/tensor_format.h"
49 
50 using mkldnn::convolution_forward;
51 using mkldnn::prop_kind;
52 using mkldnn::stream;
53 using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
54 using ReorderPd = mkldnn::reorder::primitive_desc;
55 
56 namespace tensorflow {
57 // This structure aggregates multiple inputs to Conv2DFwd* methods.
58 struct MklConvFwdParams {
59   memory::dims src_dims;
60   memory::dims filter_dims;
61   memory::dims bias_dims;
62   memory::dims dst_dims;
63   memory::dims strides;
64   memory::dims dilations;
65   memory::dims padding_left;
66   memory::dims padding_right;
67   MklTensorFormat tf_fmt;
68   bool native_format;
69   string dtypes = string("");
70   struct PostOpParam {
71     string name;
72     mkldnn::algorithm alg;
73     std::vector<float> param;
74     std::string partial_key;
75   };
76   std::vector<PostOpParam> post_op_params;
77 
MklConvFwdParamstensorflow::MklConvFwdParams78   MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims,
79                    memory::dims bias_dims, memory::dims dst_dims,
80                    memory::dims strides, memory::dims dilations,
81                    memory::dims padding_left, memory::dims padding_right,
82                    MklTensorFormat tf_fmt, bool native_format)
83       : src_dims(src_dims),
84         filter_dims(filter_dims),
85         bias_dims(bias_dims),
86         dst_dims(dst_dims),
87         strides(strides),
88         dilations(dilations),
89         padding_left(padding_left),
90         padding_right(padding_right),
91         tf_fmt(tf_fmt),
92         native_format(native_format) {}
93 };
94 
95 // With quantization, input, filter, and output can have different types
96 // so we use different template parameter for each type
97 template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
98 class MklConvFwdPrimitive : public MklPrimitive {
99  public:
MklConvFwdPrimitive(const MklConvFwdParams & convFwdDims)100   explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
101       : MklPrimitive(engine(engine::kind::cpu, 0)) {
102     // Create convolution primitive
103     if (context_.conv_fwd == nullptr) {
104       Setup(convFwdDims);
105     }
106   }
~MklConvFwdPrimitive()107   ~MklConvFwdPrimitive() {}
108 
109   // Convolution forward execute with bias
110   //   src_data:    input data buffer of src
111   //   filter_data: input data buffer of filter (weights)
112   //   bias_data:   input data buffer of bias
113   //   dst_data:    output data buffer of dst
Execute(const Tinput * src_data,const Tfilter * filter_data,const Tbias * bias_data,const Toutput * dst_data,std::shared_ptr<stream> fwd_stream)114   void Execute(const Tinput* src_data, const Tfilter* filter_data,
115                const Tbias* bias_data, const Toutput* dst_data,
116                std::shared_ptr<stream> fwd_stream) {
117 #ifdef ENABLE_MKLDNN_THREADPOOL
118     // TODO: Create a common function and avoid the duplicate code
119     context_.src_mem->set_data_handle(
120         static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
121     context_.filter_mem->set_data_handle(
122         static_cast<void*>(const_cast<Tfilter*>(filter_data)), *fwd_stream);
123     if (bias_data != nullptr) {
124       context_.bias_mem->set_data_handle(
125           static_cast<void*>(const_cast<Tbias*>(bias_data)), *fwd_stream);
126     }
127     context_.dst_mem->set_data_handle(
128         static_cast<void*>(const_cast<Toutput*>(dst_data)), *fwd_stream);
129 #else
130     context_.src_mem->set_data_handle(
131         static_cast<void*>(const_cast<Tinput*>(src_data)));
132     context_.filter_mem->set_data_handle(
133         static_cast<void*>(const_cast<Tfilter*>(filter_data)));
134     if (bias_data != nullptr) {
135       context_.bias_mem->set_data_handle(
136           static_cast<void*>(const_cast<Tbias*>(bias_data)));
137     }
138     context_.dst_mem->set_data_handle(
139         static_cast<void*>(const_cast<Toutput*>(dst_data)));
140 #endif  // ENABLE_MKLDNN_THREADPOOL
141 
142     DCHECK_EQ(context_.fwd_primitives.size(),
143               context_.fwd_primitives_args.size());
144     for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
145       context_.fwd_primitives.at(i).execute(*fwd_stream,
146                                             context_.fwd_primitives_args.at(i));
147     }
148 
149     // After execution, set data handle back
150     context_.src_mem->set_data_handle(DummyData);
151     context_.filter_mem->set_data_handle(DummyData);
152     if (bias_data != nullptr) {
153       context_.bias_mem->set_data_handle(DummyData);
154     }
155     context_.dst_mem->set_data_handle(DummyData);
156   }
157 
158   // Convolution forward execute without bias
159   //   src_data:    input data buffer of src
160   //   filter_data: input data buffer of filter (weights)
161   //   dst_data:    output data buffer of dst
Execute(const Tinput * src_data,const Tfilter * filter_data,const Toutput * dst_data,std::shared_ptr<stream> fwd_stream)162   void Execute(const Tinput* src_data, const Tfilter* filter_data,
163                const Toutput* dst_data, std::shared_ptr<stream> fwd_stream) {
164     Execute(src_data, filter_data, nullptr, dst_data, fwd_stream);
165   }
166 
GetPrimitiveDesc() const167   std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
168     return context_.fwd_pd;
169   }
170 
171  private:
172   // Primitive reuse context for Conv2D Fwd op
173   struct ConvFwdContext {
174     // MKL-DNN memory
175     std::shared_ptr<mkldnn::memory> src_mem;
176     std::shared_ptr<mkldnn::memory> filter_mem;
177     std::shared_ptr<mkldnn::memory> bias_mem;
178     std::shared_ptr<mkldnn::memory> dst_mem;
179 
180     // Desc & primitive desc
181     std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
182 
183     // Memory desc
184     std::shared_ptr<mkldnn::memory::desc> src_md;
185     std::shared_ptr<mkldnn::memory::desc> filter_md;
186     std::shared_ptr<mkldnn::memory::desc> bias_md;
187     std::shared_ptr<mkldnn::memory::desc> dst_md;
188 
189     // Convolution primitive
190     std::shared_ptr<ConvFwdPd> fwd_pd;
191     std::shared_ptr<mkldnn::primitive> conv_fwd;
192 
193     std::vector<mkldnn::primitive> fwd_primitives;
194     std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
195 
ConvFwdContexttensorflow::MklConvFwdPrimitive::ConvFwdContext196     ConvFwdContext()
197         : src_mem(nullptr),
198           filter_mem(nullptr),
199           bias_mem(nullptr),
200           dst_mem(nullptr),
201           fwd_desc(nullptr),
202           src_md(nullptr),
203           filter_md(nullptr),
204           bias_md(nullptr),
205           fwd_pd(nullptr),
206           conv_fwd(nullptr) {}
207   };
208 
Setup(const MklConvFwdParams & convFwdDims)209   void Setup(const MklConvFwdParams& convFwdDims) {
210     memory::format_tag user_data_fmt;
211     if (convFwdDims.native_format) {
212       user_data_fmt = MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt);
213     } else {
214       // Create memory descriptors for convolution data w/ no specified format
215       user_data_fmt = memory::format_tag::any;
216     }
217     context_.src_md.reset(new memory::desc(
218         {convFwdDims.src_dims}, MklDnnType<Tinput>(), user_data_fmt));
219 
220     context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims},
221                                               MklDnnType<Tfilter>(),
222                                               memory::format_tag::any));
223 
224     context_.dst_md.reset(new memory::desc(
225         {convFwdDims.dst_dims}, MklDnnType<Toutput>(), user_data_fmt));
226 
227     if (!convFwdDims.bias_dims.empty())
228       context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims},
229                                               MklDnnType<Tbias>(),
230                                               memory::format_tag::any));
231 
232     // Create a convolution descriptor
233     if (!convFwdDims.bias_dims.empty()) {
234       context_.fwd_desc.reset(new convolution_forward::desc(
235           prop_kind::forward, mkldnn::algorithm::convolution_direct,
236           *context_.src_md, *context_.filter_md, *context_.bias_md,
237           *context_.dst_md, convFwdDims.strides, convFwdDims.dilations,
238           convFwdDims.padding_left, convFwdDims.padding_right));
239     } else {
240       context_.fwd_desc.reset(new convolution_forward::desc(
241           prop_kind::forward, mkldnn::algorithm::convolution_direct,
242           *context_.src_md, *context_.filter_md, *context_.dst_md,
243           convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
244           convFwdDims.padding_right));
245     }
246 
247     context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
248 
249     // Check if there is any fusions as post-ops
250     auto const& post_op_params = convFwdDims.post_op_params;
251     mkldnn::primitive_attr post_ops_attr;
252     mkldnn::post_ops post_ops;
253     if (!post_op_params.empty()) {
254       for (auto const& post_op_param : post_op_params) {
255         if (post_op_param.name == "activation") {
256           DCHECK_EQ(post_op_param.param.size(), 3);
257           float op_scale = post_op_param.param[0];
258           float op_alpha = post_op_param.param[1];
259           float op_beta = post_op_param.param[2];
260           post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha,
261                                   op_beta);
262         } else if (post_op_param.name == "sum") {
263           DCHECK_EQ(post_op_param.param.size(), 1);
264           float op_scale = post_op_param.param[0];
265           post_ops.append_sum(op_scale);
266         } else if (post_op_param.name == "output_scale") {
267           if (post_op_param.param.size() == 1) {
268             post_ops_attr.set_output_scales(0, post_op_param.param);
269           } else {
270             post_ops_attr.set_output_scales(2, post_op_param.param);
271           }
272         } else {
273           DCHECK((post_op_param.name == "activation") ||
274                  (post_op_param.name == "sum") ||
275                  (post_op_param.name == "output_scale"));
276         }
277       }
278       post_ops_attr.set_post_ops(post_ops);
279       context_.fwd_pd.reset(
280           new ConvFwdPd(*context_.fwd_desc, post_ops_attr, cpu_engine_));
281     } else {
282       context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
283     }
284 
285     // Create memory primitive based on dummy data
286     context_.src_mem.reset(
287         new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
288     context_.filter_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
289                                          cpu_engine_, DummyData));
290     context_.dst_mem.reset(
291         new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));
292 
293     // Create convolution primitive and add it to net
294     if (!convFwdDims.bias_dims.empty()) {
295       context_.bias_mem.reset(new memory(
296           {{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format_tag::x},
297           cpu_engine_, DummyData));
298       context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
299       context_.fwd_primitives_args.push_back(
300           {{MKLDNN_ARG_SRC, *context_.src_mem},
301            {MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
302            {MKLDNN_ARG_BIAS, *context_.bias_mem},
303            {MKLDNN_ARG_DST, *context_.dst_mem}});
304     } else {
305       context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
306       context_.fwd_primitives_args.push_back(
307           {{MKLDNN_ARG_SRC, *context_.src_mem},
308            {MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
309            {MKLDNN_ARG_DST, *context_.dst_mem}});
310     }
311     context_.fwd_primitives.push_back(*context_.conv_fwd);
312   }
313 
314   struct ConvFwdContext context_;
315 };
316 
317 // TODO(nhasabni): We should not require passing a type to MklPrimitiveFactory.
318 // But removing the need for type in MklPrimitiveFactory is going to require
319 // change to every MKL op. So not doing it now. Instead passing float.
320 template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
321 class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<float> {
322  public:
Get(const MklConvFwdParams & convFwdDims,bool do_not_cache)323   static MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>* Get(
324       const MklConvFwdParams& convFwdDims, bool do_not_cache) {
325     MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>* conv_fwd = nullptr;
326 
327     if (do_not_cache) {
328       // Always create a new primitive
329       conv_fwd =
330           new MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>(convFwdDims);
331     } else {
332       // Try to find a suitable one in pool
333       conv_fwd =
334           dynamic_cast<MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>*>(
335               MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias,
336                                          Toutput>::GetInstance()
337                   .GetConvFwd(convFwdDims));
338       if (conv_fwd == nullptr) {
339         conv_fwd = new MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>(
340             convFwdDims);
341         MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias,
342                                    Toutput>::GetInstance()
343             .SetConvFwd(convFwdDims, conv_fwd);
344       }
345     }
346 
347     return conv_fwd;
348   }
349 
350  private:
MklConvFwdPrimitiveFactory()351   MklConvFwdPrimitiveFactory() {}
~MklConvFwdPrimitiveFactory()352   ~MklConvFwdPrimitiveFactory() {}
353 
354   static const int kDilationH = 0, kDilationW = 1;
355 
GetInstance()356   static MklConvFwdPrimitiveFactory& GetInstance() {
357     static MklConvFwdPrimitiveFactory instance_;
358     return instance_;
359   }
360 
CreateKey(const MklConvFwdParams & convFwdDims)361   static string CreateKey(const MklConvFwdParams& convFwdDims) {
362     string prefix = "conv_fwd_";
363     FactoryKeyCreator key_creator;
364     key_creator.AddAsKey(prefix);
365     key_creator.AddAsKey(convFwdDims.src_dims);
366     key_creator.AddAsKey(convFwdDims.filter_dims);
367     key_creator.AddAsKey(convFwdDims.bias_dims);
368     key_creator.AddAsKey(convFwdDims.dst_dims);
369     key_creator.AddAsKey(convFwdDims.strides);
370     key_creator.AddAsKey(convFwdDims.dilations);
371     key_creator.AddAsKey(convFwdDims.padding_left);
372     key_creator.AddAsKey(convFwdDims.padding_right);
373     key_creator.AddAsKey(convFwdDims.dtypes);
374     if (convFwdDims.native_format) {
375       key_creator.AddAsKey(convFwdDims.tf_fmt);
376     }
377 
378     // Generate keys for post-ops
379     for (auto const& post_op_param : convFwdDims.post_op_params) {
380       key_creator.AddAsKey(post_op_param.name);
381       if (post_op_param.name == "activation") {
382         DCHECK_EQ(post_op_param.param.size(), 3);
383         for (auto& param : post_op_param.param) {
384           key_creator.AddAsKey(param);
385         }
386       } else if (post_op_param.name == "sum") {
387         DCHECK_EQ(post_op_param.param.size(), 1);
388         for (auto& param : post_op_param.param) {
389           key_creator.AddAsKey(param);
390         }
391       } else if (post_op_param.name == "output_scale") {
392         key_creator.AddAsKey(post_op_param.partial_key);
393       } else {
394         return string("not_a_key");
395       }
396     }
397 
398     return key_creator.GetKey();
399   }
400 
GetConvFwd(const MklConvFwdParams & convFwdDims)401   MklPrimitive* GetConvFwd(const MklConvFwdParams& convFwdDims) {
402     string key = CreateKey(convFwdDims);
403     return this->GetOp(key);
404   }
405 
SetConvFwd(const MklConvFwdParams & convFwdDims,MklPrimitive * op)406   void SetConvFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
407     string key = CreateKey(convFwdDims);
408     this->SetOp(key, op);
409   }
410 };
411 
412 // Base class for convolution forward operations
413 template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
414           typename Toutput, typename Ttemp_output, typename Tpadding,
415           bool bias_enabled, bool pad_enabled, bool is_depthwise,
416           bool native_format>
417 class MklConvOp : public OpKernel {
418  public:
// No custom teardown beyond ordinary member destruction.
~MklConvOp() = default;
420 
MklConvOp(OpKernelConstruction * context)421   explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
422     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
423 
424     // Conv and QuantizedConv ops have different padding attributes
425     // (`padding_list` versus `explicit_paddings`). But one and only one
426     // attribute is expected.
427     OP_REQUIRES(
428         context,
429         !(context->HasAttr("padding_list") &&
430           context->HasAttr("explicit_paddings")),
431         errors::InvalidArgument("Can only have 1 `padding` list at most"));
432     if (context->HasAttr("padding_list")) {
433       OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_));
434     }
435     if (context->HasAttr("explicit_paddings")) {
436       OP_REQUIRES_OK(context,
437                      context->GetAttr("explicit_paddings", &padding_list_));
438     }
439 
440     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
441     string data_format;
442     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
443     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
444                 errors::InvalidArgument("Invalid data format"));
445     OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5),
446                 errors::InvalidArgument("Sliding window strides field must "
447                                         "specify 4 or 5 dimensions"));
448 
449     const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
450     const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
451     OP_REQUIRES(
452         context, stride_n == 1 && stride_c == 1,
453         errors::Unimplemented("Current implementation does not yet support "
454                               "strides in the batch and depth dimensions."));
455 
456     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
457     is_filter_const_ = false;
458     if (context->HasAttr("is_filter_const")) {
459       OP_REQUIRES_OK(context,
460                      context->GetAttr("is_filter_const", &is_filter_const_));
461     }
462 
463     if (strides_.size() == 4) {
464       OP_REQUIRES(context, dilations_.size() == 4,
465                   errors::InvalidArgument("Sliding window dilations field must "
466                                           "specify 4 dimensions"));
467       const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
468       const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
469       const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
470       const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
471       OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
472                   errors::InvalidArgument(
473                       "Current implementation does not yet support "
474                       "dilations in the batch and depth dimensions."));
475       OP_REQUIRES(
476           context, dilation_h > 0 && dilation_w > 0,
477           errors::InvalidArgument("Dilated rates should be larger than 0."));
478     } else if (strides_.size() == 5) {
479       OP_REQUIRES(context, dilations_.size() == 5,
480                   errors::InvalidArgument("Dilation rates field must "
481                                           "specify 5 dimensions"));
482       OP_REQUIRES(context,
483                   (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
484                    GetTensorDim(dilations_, data_format_, 'C') == 1),
485                   errors::InvalidArgument(
486                       "Current implementation does not yet support "
487                       "dilations rates in the batch and depth dimensions."));
488       OP_REQUIRES(
489           context,
490           (GetTensorDim(dilations_, data_format_, '0') > 0 &&
491            GetTensorDim(dilations_, data_format_, '1') > 0 &&
492            GetTensorDim(dilations_, data_format_, '2') > 0),
493           errors::InvalidArgument("Dilated rates should be larger than 0."));
494     }
495   }
496 
Compute(OpKernelContext * context)497   void Compute(OpKernelContext* context) override {
498     try {
499       // Input tensors
500       const Tensor& src_tensor = MklGetInput(context, kInputIndex_Src);
501       const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);
502       MklDnnShape src_mkl_shape, filter_mkl_shape;
503       GetMklShape(context, kInputIndex_Src, &src_mkl_shape, native_format);
504       GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape,
505                   native_format);
506 
507       OP_REQUIRES(context, !filter_mkl_shape.IsMklTensor(),
508                   errors::InvalidArgument("Filter should not be in "
509                                           "Mkl Layout"));
510 
511       MklDnnData<Tinput> src(&cpu_engine_);
512       MklDnnData<Tfilter> filter(&cpu_engine_);
513 
514       memory::dims src_dims, filter_dims, padding_left, padding_right,
515           dilations, strides;
516       memory::dims dst_dims_tf_order, dst_dims_mkl_order;
517 
518       // For any Conv with `EXPLICIT` padding, get padding from `padding_list`
519       // attribute. Otherwise, get it from one of the inputs.
520       bool pad_attr_enabled = false;
521       for (auto const& padding_val : padding_list_) {
522         if (padding_val) {
523           pad_attr_enabled = true;
524 
525           break;
526         }
527       }
528 
529       if (fuse_pad_ || pad_attr_enabled) {
530         PadWithConvFusion(context, padding_left, padding_right,
531                           pad_attr_enabled);
532       }
533 
534       // Get shapes of input tensors in MKL-DNN order
535       MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
536                               dilations_);
537       auto src_tf_shape = GetTfShape(context, kInputIndex_Src, native_format);
538       auto filter_tf_shape =
539           GetTfShape(context, kInputIndex_Filter, native_format);
540       conv_utl.GetConvFwdSizesInMklOrder(
541           src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
542           &dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left,
543           &padding_right, (fuse_pad_ || pad_attr_enabled), is_depthwise);
544 
545       if (!context->status().ok()) return;
546 
547       // Check for corner case - if there is nothing to compute, return.
548       TensorShape dst_tf_shape = MklDnnDimsToTFShape(dst_dims_tf_order);
549 
550       // Corner cases: output with 0 elements and 0 batch size.
551       Tensor* dst_tensor = nullptr;
552       bool emit_filter_output = (typeid(Tinput) == typeid(Tfilter) &&
553                                  typeid(Tinput) == typeid(Toutput) &&
554                                  (typeid(Tinput) == typeid(float) ||
555                                   typeid(Tinput) == typeid(bfloat16))) &&
556                                 !native_format;
557       if (dst_tf_shape.num_elements() == 0 || dst_dims_tf_order[0] == 0) {
558         MklDnnShape dst_mkl_shape;
559         dst_mkl_shape.SetMklTensor(false);
560         AllocateOutputSetMklShape(context, kOutputIndex_Dst, &dst_tensor,
561                                   src_tf_shape, dst_mkl_shape, native_format);
562 
563         // MklConv2D/3D also outputs converted filter as 2nd output.
564         filter_mkl_shape.SetMklTensor(false);
565         Tensor* output_filter_tensor = nullptr;
566         if (emit_filter_output) {
567           filter_mkl_shape.SetMklTensor(false);
568           AllocateOutputSetMklShape(context, kOutputIndex_Filter,
569                                     &output_filter_tensor, filter_tf_shape,
570                                     filter_mkl_shape);
571         }
572         return;
573       }
574 
575       bool is_conv2d = (strides_.size() == 4);
576 
577       if (!is_conv2d) {
578         OP_REQUIRES(
579             context, !pad_enabled,
580             errors::InvalidArgument("Pad + Conv fusion only works for 2D"));
581         OP_REQUIRES(
582             context, !fuse_pad_,
583             errors::InvalidArgument("Pad+Conv fusion only works for 2D"));
584       }
585 
586       // TODO(gzmkl) 3-D support for Depthwise is not there
587       if (is_depthwise) {
588         OP_REQUIRES(context, is_conv2d,
589                     errors::InvalidArgument(
590                         "Only 2D convolution is supported for depthwise."));
591       }
592 
593       // Create memory for user data.
594       // Describe how the inputs and outputs of Convolution look like. Also
595       // specify buffers containing actual input and output data.
596       auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
597                               : TFDataFormatToMklDnn3DDataFormat(data_format_);
598 
599       auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
600       // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
601       OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
602                   errors::InvalidArgument("Invalid data format"));
603 
604       // If input is in MKL layout, then simply grab the layout; otherwise,
605       // construct TF layout for input.
606       // For constructing TF layout for input, although input shape (src_dims)
607       // is required to be in MKL-DNN order, the input layout is actually in
608       // TF layout depending on the data format:
609       //     Conv2D: NHWC or NCHW
610       //     Conv3D: NDHWC or NCDHW
611       auto src_md =
612           src_mkl_shape.IsMklTensor()
613               ? src_mkl_shape.GetMklLayout()
614               : memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag);
615       src.SetUsrMem(src_md, &src_tensor);
616 
617       // Although filter shape (filter_dims) required is in MKL-DNN order,
618       // the layout is Tensorflow's layout (HWIO) and (HWIGO) for
619       // depthwise/group convolutions.
620       auto filter_format = is_conv2d ? (is_depthwise ? memory::format_tag::hwigo
621                                                      : memory::format_tag::hwio)
622                                      : memory::format_tag::dhwio;
623 
624       DCHECK(!filter_mkl_shape.IsMklTensor());
625       auto filter_md =
626           filter_mkl_shape.IsMklTensor()
627               ? filter_mkl_shape.GetMklLayout()
628               : memory::desc(filter_dims, MklDnnType<Tfilter>(), filter_format);
629       filter.SetUsrMem(filter_md, &filter_tensor);
630 
631       // MKL-DNN dilations start from 0.
632       for (int i = 0; i < dilations.size(); ++i) --dilations[i];
633 
634       // In some cases, primitive descriptor could potentially contain
635       // large buffers. As a result, we don't cache these primitives if the
636       // environment variable `TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE` is set to True.
637       // MKL-DNN allocates buffers in the following cases:
638       //   1. Legacy CPU without AVX512/AVX2, or
639       //   2. 1x1 convolution with strides != 1
640       bool do_not_cache =
641           MklPrimitiveFactory<Tinput>::IsPrimitiveMemOptEnabled() &&
642           (src_dims[MklDnnDims::Dim_N] > kSmallBatchSize) &&
643           (MklPrimitiveFactory<Tinput>::IsLegacyPlatform() ||
644            IsConv1x1StrideNot1(filter_dims, strides));
645 
646       // Get a conv2d fwd from primitive pool
647       MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Ttemp_output>* conv_fwd =
648           nullptr;
649       memory::dims bias_dims = {};
650       if (fuse_biasadd_) {
651         conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
652       }
653       MklConvFwdParams convFwdDims(
654           src_dims, filter_dims, fuse_biasadd_ ? bias_dims : NONE_DIMS,
655           dst_dims_mkl_order, strides, dilations, padding_left, padding_right,
656           tf_fmt, native_format);
657 
658       // TODO(mdfaijul): Extend the basic parameters for data types and fusions
659       this->ExtendConvFwdParams(context, convFwdDims);
660       conv_fwd =
661           MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, Ttemp_output>::Get(
662               convFwdDims, do_not_cache);
663       // Allocate output tensors `dst_tensor` and `filter_out_tensor`
664       MklDnnShape output_mkl_shape;
665       std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
666       AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt,
667                            &output_mkl_shape, &dst_tensor);
668 
669       Tensor* filter_out_tensor = nullptr;
670       if (emit_filter_output) {
671         AllocateFilterOutputTensor(context, *conv_fwd_pd,
672                                    TFShapeToMklDnnDims(filter_tf_shape),
673                                    &filter_out_tensor);
674       }
675 
676       Ttemp_output* dst_data =
677           reinterpret_cast<Ttemp_output*>(dst_tensor->flat<Toutput>().data());
678 
679       // Check whether src and filter need to be reordered.
680       Tinput* src_data = nullptr;
681       if (src_md != conv_fwd_pd->src_desc()) {
682         src.SetUsrMem(src_md, &src_tensor);
683         src.CheckReorderToOpMem(conv_fwd_pd->src_desc(), cpu_engine_, context);
684         src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
685       } else {
686         src_data = static_cast<Tinput*>(
687             const_cast<Tinput*>(src_tensor.flat<Tinput>().data()));
688       }
689 
690       Tfilter* filter_data = nullptr;
691       if (filter_md != conv_fwd_pd->weights_desc()) {
692         bool is_filter_cached = false;
693         // If filter is a constant, we can avoid the conversion of filter from
694         // Tensorflow format to MKL format by caching the filter when it is
695         // converted for the first time. This cached filter can then be reused
696         // in subsequent iterations.
697         if (is_filter_const_) {
698           if (IsFilterCacheEmpty(context)) {
699             // Cache filter if it is not already cached.
700             CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
701                         filter, filter_md, filter_mkl_shape);
702           }
703           filter_data = GetCachedFilter(context, conv_fwd_pd->weights_desc());
704           is_filter_cached = (filter_data != nullptr);
705         }
706         if (!is_filter_cached) {
707           filter.SetUsrMem(filter_md, &filter_tensor);
708           if (filter_out_tensor == nullptr) {
709             filter.CheckReorderToOpMem(conv_fwd_pd->weights_desc(), cpu_engine_,
710                                        context);
711           } else {
712             filter.CheckReorderToOpMem(
713                 conv_fwd_pd->weights_desc(),
714                 filter.GetTensorBuffer(filter_out_tensor), cpu_engine_,
715                 context);
716           }
717           filter_data =
718               static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
719         }
720       } else {
721         filter_data = static_cast<Tfilter*>(
722             const_cast<Tfilter*>(filter_tensor.flat<Tfilter>().data()));
723       }
724 
725       // Execute convolution
726       std::shared_ptr<stream> fwd_cpu_stream;
727       fwd_cpu_stream.reset(CreateStream(context, conv_fwd->GetEngine()));
728       if (fuse_biasadd_) {
729         const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
730         Tbias* bias_data =
731             this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
732         conv_fwd->Execute(src_data, filter_data, bias_data, dst_data,
733                           fwd_cpu_stream);
734       } else {
735         conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream);
736       }
737 
738       // Delete primitive since it is not cached.
739       if (do_not_cache) delete conv_fwd;
740 
741     } catch (mkldnn::error& e) {
742       string error_msg = tensorflow::strings::StrCat(
743           "Status: ", e.status, ", message: ", string(e.message), ", in file ",
744           __FILE__, ":", __LINE__);
745       OP_REQUIRES_OK(
746           context,
747           errors::Aborted("Operation received an exception:", error_msg));
748     }
749   }
750 
PadWithConvFusion(OpKernelContext * context,memory::dims & padding_left,memory::dims & padding_right,bool pad_attr_enabled)751   void PadWithConvFusion(OpKernelContext* context, memory::dims& padding_left,
752                          memory::dims& padding_right, bool pad_attr_enabled) {
753     Tpadding* paddings = nullptr;
754     if (pad_attr_enabled) {
755       paddings = padding_list_.data();
756     } else {
757       const Tensor& paddings_tf = MklGetInput(context, input_index_pad_);
758       OP_REQUIRES(context, paddings_tf.dims() == 2,
759                   errors::InvalidArgument("paddings must be 2-dimensional: ",
760                                           paddings_tf.shape().DebugString()));
761       // Flatten tensor to get individual paddings.
762       paddings = static_cast<Tpadding*>(
763           const_cast<Tpadding*>(paddings_tf.flat<Tpadding>().data()));
764     }
765     // If the data format is NHWC, indices 0, 1, 6 and 7 of paddings(_tf)
766     // will be zero.
767     // Example:
768     // paddings_tf = [ [0, 0] [1, 2] [3, 4] [0, 0] ],
769     // flat method = row-major, then:
770     // paddings = {0, 0, 1, 2, 3, 4, 0, 0}.
771     // Hence, the values are: top = 1, bottom = 2, left = 3, right = 4.
772     //
773     // Similarly, if the data format is NCHW, indices 0, 1, 2 and 3 of
774     // paddings(_tf) will be zero.
775     // i.e. for the above example, paddings = {0, 0, 0, 0, 1, 2, 3, 4}.
776     int64 pad_top = 0, pad_left = 0;
777     int64 pad_bottom = 0, pad_right = 0;
778     string data_format = ToString(data_format_);
779     if (data_format == "NHWC") {
780       pad_top = paddings[2];
781       pad_bottom = paddings[3];
782       pad_left = paddings[4];
783       pad_right = paddings[5];
784     } else if (data_format == "NCHW") {
785       pad_top = paddings[4];
786       pad_bottom = paddings[5];
787       pad_left = paddings[6];
788       pad_right = paddings[7];
789     }
790     // Create padding arrays for MKL-DNN convolutions.
791     // MKL-DNN uses asymmetric padding.
792     padding_left = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
793     padding_right = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
794   }
795 
796  protected:
set_fuse_biasadd(bool fuse_biasadd)797   void set_fuse_biasadd(bool fuse_biasadd) { fuse_biasadd_ = fuse_biasadd; }
set_fuse_activation(bool fuse_activation,mkldnn::algorithm activation_alg,float alpha_or_upbound=0.0)798   void set_fuse_activation(bool fuse_activation,
799                            mkldnn::algorithm activation_alg,
800                            float alpha_or_upbound = 0.0) {
801     fuse_activation_ = fuse_activation;
802     activation_alg_ = activation_alg;
803     // This variable is used for alpha in leakyrelu or upper bound in relu6
804     // depending on the context
805     alpha_or_upbound_ = alpha_or_upbound;
806   }
set_fuse_pad(bool fuse_pad)807   void set_fuse_pad(bool fuse_pad) {
808     fuse_pad_ = fuse_pad;
809     // In PadwithFusedConv OP, pad is the fourth index.
810     input_index_pad_ = 3;
811   }
set_fuse_add(bool fuse_add)812   void set_fuse_add(bool fuse_add) { fuse_add_ = fuse_add; }
813 
814   // This method is for the base class MklConvOp, which handles the
815   // floating point implementation of Conv. The quantized conv implementations
816   // will use overridden versions of this method.
ExtendConvFwdParams(OpKernelContext * context,MklConvFwdParams & params)817   virtual void ExtendConvFwdParams(OpKernelContext* context,
818                                    MklConvFwdParams& params) {
819     // Create a string from data types of input, filter, bias, and output.
820     params.dtypes.append(typeid(Tinput).name());
821     params.dtypes.append(typeid(Tfilter).name());
822     params.dtypes.append(typeid(Tbias).name());
823     params.dtypes.append(typeid(Toutput).name());
824 
825     // Add fusions as post ops
826     // NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
827     // checking `fuse_biasadd_` flag.
828     if (fuse_add_) {
829       params.post_op_params.push_back(
830           {"sum", mkldnn::algorithm::undef, {1.0}, ""});
831     }
832     if (fuse_activation_) {
833       params.post_op_params.push_back(
834           {"activation", activation_alg_, {1.0, alpha_or_upbound_, 0.0}, ""});
835     }
836   }
837 
GetBiasHandle(OpKernelContext * context,std::shared_ptr<ConvFwdPd> & conv2d_fwd_pd,const Tensor & bias_tensor)838   virtual Tbias* GetBiasHandle(OpKernelContext* context,
839                                std::shared_ptr<ConvFwdPd>& conv2d_fwd_pd,
840                                const Tensor& bias_tensor) {
841     if (fuse_biasadd_) {
842       return static_cast<Tbias*>(
843           const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
844     }
845     return nullptr;
846   }
847 
AllocateOutputTensor(OpKernelContext * context,const ConvFwdPd & conv_prim_desc,const memory::dims & output_dims_mkl_order,MklTensorFormat output_tf_format,MklDnnShape * output_mkl_shape,Tensor ** output_tensor)848   virtual void AllocateOutputTensor(OpKernelContext* context,
849                                     const ConvFwdPd& conv_prim_desc,
850                                     const memory::dims& output_dims_mkl_order,
851                                     MklTensorFormat output_tf_format,
852                                     MklDnnShape* output_mkl_shape,
853                                     Tensor** output_tensor) {
854     DCHECK(output_tensor);
855     auto dst_md = conv_prim_desc.dst_desc();
856 
857     if (!std::is_same<Ttemp_output, Toutput>::value) {
858       dst_md.data.data_type =
859           static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
860     }
861 
862     // Allocate shape of MKL tensor
863     output_mkl_shape->SetMklTensor(true);
864     output_mkl_shape->SetMklLayout(&dst_md);
865     output_mkl_shape->SetElemType(MklDnnType<Toutput>());
866     output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(),
867                                   output_dims_mkl_order, output_tf_format);
868 
869     // Allocate shape of TF tensor
870     TensorShape output_tf_shape;
871     output_tf_shape.AddDim((dst_md.get_size() / sizeof(Toutput)));
872     if (native_format) {
873       output_tf_shape = output_mkl_shape->GetTfShape();
874     }
875 
876     if (fuse_add_) {
877       const Tensor& add_tensor = MklGetInput(context, kInputIndex_Add);
878       MklDnnShape add_mkl_shape;
879       GetMklShape(context, kInputIndex_Add, &add_mkl_shape, native_format);
880       // Forward the summand tensor to the output only if it has no other
881       // references, otherwise make a copy of it.
882       if (native_format && context->forward_input_to_output_with_shape(
883                                kInputIndex_Add, kOutputIndex_Dst,
884                                output_tf_shape, output_tensor)) {
885         return;
886       }
887       // Check if reorder is needed
888       if (!native_format && add_mkl_shape == *output_mkl_shape &&
889           ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
890                                               kOutputIndex_Dst, output_tensor,
891                                               add_mkl_shape, false)) {
892         return;
893       } else {
894         AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
895                                   output_tf_shape, *output_mkl_shape,
896                                   native_format);
897         auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
898             output_mkl_shape->GetTfDataFormat());
899         OP_REQUIRES(context, output_format_tag != memory::format_tag::undef,
900                     errors::InvalidArgument(
901                         "MklConvOp: AddN fusion: Invalid data format"));
902         auto add_md =
903             add_mkl_shape.IsMklTensor()
904                 ? add_mkl_shape.GetMklLayout()
905                 : memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
906                                output_format_tag);
907         void* add_buf = static_cast<void*>(
908             const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
909         void* dst_buf =
910             static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
911         if (native_format) {
912           // We are simply deep copying the add_tensor to output_tensor without
913           // changing memory layout, hence using same memory descriptor.
914           add_md = dst_md =
915               memory::desc({add_tensor.NumElements()}, MklDnnType<Toutput>(),
916                            mkldnn::memory::format_tag::x);
917         }
918         fuse_add_src_.reset(new memory(add_md, this->cpu_engine_, add_buf));
919         fuse_add_dst_.reset(new memory(dst_md, this->cpu_engine_, dst_buf));
920         auto reorder_desc =
921             ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md);
922 
923         CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
924                                 this->cpu_engine_, context);
925       }
926     } else {
927       AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
928                                 output_tf_shape, *output_mkl_shape,
929                                 native_format);
930     }
931   }
932 
// CPU engine (index 0) used by this kernel for its reorders and memory objects.
engine cpu_engine_ = engine(engine::kind::cpu, 0);

private:
// Memory objects wrapping the summand (src) and output (dst) buffers for the
// Add-fusion reorder in AllocateOutputTensor. NOTE(review): presumably kept as
// members so they outlive the reorder; confirm.
std::shared_ptr<mkldnn::memory> fuse_add_src_;
std::shared_ptr<mkldnn::memory> fuse_add_dst_;
// Convolution strides and dilations. Presumably filled from op attributes in
// the constructor, which is not visible here (TODO confirm).
std::vector<int32> strides_;
std::vector<int32> dilations_;
// Explicit paddings read by PadWithConvFusion when pad_attr_enabled is true.
std::vector<Tpadding> padding_list_;
// When true, Compute may cache the reordered filter (CacheFilter /
// GetCachedFilter) instead of reordering it on every call.
bool is_filter_const_;
// Guards the cached-filter persistent tensors below.
mutex mu_;
// Padding scheme of the op (not used within this chunk).
Padding padding_;
// Data layout of the op; PadWithConvFusion uses it to select which paddings
// entries are the spatial ones.
TensorFormat data_format_;
// Filter data already reordered to the primitive's weights layout.
PersistentTensor cached_filter_data_ptensor_ TF_GUARDED_BY(mu_);
// Raw bytes of the memory::desc the cached filter was reordered to.
PersistentTensor cached_filter_md_ptensor_ TF_GUARDED_BY(mu_);

// Initialize to values the template is instantiated with
bool fuse_biasadd_ = bias_enabled;
bool fuse_activation_ = false;
bool fuse_pad_ = pad_enabled;
bool fuse_add_ = false;

// This variable is used for alpha in leakyrelu or upper bound in relu6
// depending on the context
float alpha_or_upbound_ = 0.0;
// Eltwise algorithm used for the activation post-op when fuse_activation_ is set.
mkldnn::algorithm activation_alg_ = mkldnn::algorithm::undef;

// Input index of the paddings tensor: 2 by default, 3 after set_fuse_pad().
int input_index_pad_ = 2;

// Input/output slot indices used with MklGetInput and the output allocators.
const int kInputIndex_Src = 0, kInputIndex_Filter = 1, kInputIndex_Bias = 2;
const int kInputIndex_Add = 3;
const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
// Presumably indices into dilations_ for the H and W dilations (TODO confirm).
const int kDilationH = 0, kDilationW = 1;
965 
GetFilterTfDataFormat(const MklDnnShape * filter_mkl_shape,const ConvFwdPd & conv_prim_desc) const966   MklTensorFormat GetFilterTfDataFormat(const MklDnnShape* filter_mkl_shape,
967                                         const ConvFwdPd& conv_prim_desc) const {
968     DCHECK(filter_mkl_shape);
969     return filter_mkl_shape->GetTfDataFormat();
970   }
971 
972   // Allocate persistent tensors for cached filter data and
973   // cached filter memory descriptor (data format)
AllocatePersistentTensor(OpKernelContext * context,const ConvFwdPd & conv_prim_desc,Tensor ** filter_tensor,const MklDnnShape * filter_mkl_shape)974   void AllocatePersistentTensor(OpKernelContext* context,
975                                 const ConvFwdPd& conv_prim_desc,
976                                 Tensor** filter_tensor,
977                                 const MklDnnShape* filter_mkl_shape) {
978     DCHECK(filter_tensor);
979     TensorShape filter_tf_shape;
980     filter_tf_shape.AddDim(
981         (conv_prim_desc.weights_desc().get_size() / sizeof(Tfilter)));
982     OP_REQUIRES_OK(context, context->allocate_persistent(
983                                 DataTypeToEnum<Tfilter>::value, filter_tf_shape,
984                                 &cached_filter_data_ptensor_, filter_tensor));
985 
986     Tensor* second_tensor = nullptr;
987 
988     // There is no tensor format in DNNL 1.x. So we cache the complete filter
989     // descriptor as flat byte array.
990     TensorShape cached_filter_md_shape;
991     memory::desc weights_desc = conv_prim_desc.weights_desc();
992     // We don't use .get_size() method of memory::desc since it returns size
993     // required to store primitive's input memory. It is much more than size of
994     // memory::desc itself.
995     cached_filter_md_shape.AddDim(sizeof(weights_desc) / sizeof(uint8));
996     OP_REQUIRES_OK(context, context->allocate_persistent(
997                                 DT_UINT8, cached_filter_md_shape,
998                                 &cached_filter_md_ptensor_, &second_tensor));
999     *reinterpret_cast<memory::desc*>(second_tensor->flat<uint8>().data()) =
1000         weights_desc;
1001   }
1002 
AllocatePersistentTensor(OpKernelContext * context,const ConvFwdPd & conv_prim_desc,Tensor ** filter_tensor)1003   void AllocatePersistentTensor(OpKernelContext* context,
1004                                 const ConvFwdPd& conv_prim_desc,
1005                                 Tensor** filter_tensor) {
1006     AllocatePersistentTensor(context, conv_prim_desc, filter_tensor, nullptr);
1007   }
1008 
AllocateFilterOutputTensor(OpKernelContext * context,const ConvFwdPd & conv_prim_desc,const memory::dims & filter_dims_tf_order,Tensor ** filter_tensor)1009   void AllocateFilterOutputTensor(OpKernelContext* context,
1010                                   const ConvFwdPd& conv_prim_desc,
1011                                   const memory::dims& filter_dims_tf_order,
1012                                   Tensor** filter_tensor) {
1013     DCHECK(filter_tensor);
1014     auto filter_md = conv_prim_desc.weights_desc();
1015 
1016     // Allocate shape of MKL tensor
1017     MklDnnShape filter_mkl_shape;
1018     filter_mkl_shape.SetMklTensor(true);
1019     filter_mkl_shape.SetMklLayout(&filter_md);
1020     filter_mkl_shape.SetElemType(MklDnnType<Tfilter>());
1021 
1022     // The format of the filter is actually OIhw8i8o, but TF doesn't support
1023     // this format. Just use format::blocked for now because the layout
1024     // is stored in the MKL data.
1025     filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
1026                                  filter_dims_tf_order,
1027                                  MklTensorFormat::FORMAT_BLOCKED);
1028 
1029     // Allocate the data space for the filter to propagate as TF tensor.
1030     TensorShape filter_tf_shape;
1031     filter_tf_shape.AddDim((filter_md.get_size() / sizeof(Tfilter)));
1032 
1033     AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor,
1034                               filter_tf_shape, filter_mkl_shape);
1035   }
1036 
1037   // TODO(intel-mkl): This function does not seem to be called. Remove it.
1038   // Prepare and execute net - checks for input and output reorders.
PrepareAndExecuteNet(const ConvFwdPd & conv_prim_desc,MklDnnData<Tinput> * src,MklDnnData<Tfilter> * filter,MklDnnData<Tbias> * bias,MklDnnData<Toutput> * output,Tensor * filter_out_tensor)1039   void PrepareAndExecuteNet(const ConvFwdPd& conv_prim_desc,
1040                             MklDnnData<Tinput>* src,
1041                             MklDnnData<Tfilter>* filter,
1042                             MklDnnData<Tbias>* bias,
1043                             MklDnnData<Toutput>* output,
1044                             Tensor* filter_out_tensor) {
1045     DCHECK(filter_out_tensor);
1046 
1047     // Create reorders between user layout and MKL layout if it is needed and
1048     // add it to the net before convolution. No need to check for output
1049     // reorder as we propagate output layout to the next layer.
1050     src->CheckReorderToOpMem(conv_prim_desc.src_desc(), cpu_engine_);
1051 
1052     // Rather than re-ordering to a temp buffer, reorder directly to the
1053     // filter output tensor
1054     filter->CheckReorderToOpMem(conv_prim_desc.weights_desc(),
1055                                 filter->GetTensorBuffer(filter_out_tensor));
1056 
1057     // Create convolution primitive and add it to net.
1058     std::vector<primitive> net;
1059     std::vector<std::unordered_map<int, memory>> net_args;
1060     if (bias) {
1061       DCHECK(fuse_biasadd_);
1062       net.push_back(convolution_forward(conv_prim_desc));
1063       net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
1064                           {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
1065                           {MKLDNN_ARG_BIAS, bias->GetOpMem()},
1066                           {MKLDNN_ARG_DST, output->GetOpMem()}});
1067     } else {
1068       DCHECK(!fuse_biasadd_);
1069       net.push_back(convolution_forward(conv_prim_desc));
1070       net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
1071                           {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
1072                           {MKLDNN_ARG_DST, output->GetOpMem()}});
1073     }
1074     ExecutePrimitive(net, &net_args, cpu_engine_);
1075   }
1076 
1077   // TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
1078   // be acquired before entering the function, since it is acquired
1079   // inside the function.
IsFilterCacheEmpty(OpKernelContext * context)1080   inline bool IsFilterCacheEmpty(OpKernelContext* context)
1081       TF_LOCKS_EXCLUDED(mu_) {
1082     tf_shared_lock lock(mu_);
1083     const Tensor& cached_filter_data_tensor =
1084         *cached_filter_data_ptensor_.AccessTensor(context);
1085     return (cached_filter_data_tensor.NumElements() == 0);
1086   }
1087 
1088   // Cache the converted filter in a persistent tensor.
1089   // Only one thread can execute this method at any given time.
CacheFilter(OpKernelContext * context,const std::shared_ptr<ConvFwdPd> & conv_fwd_pd,Tfilter * filter_data,const Tensor & filter_tensor,MklDnnData<Tfilter> & filter,const memory::desc & filter_md,const MklDnnShape & filter_mkl_shape)1090   void CacheFilter(OpKernelContext* context,
1091                    const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
1092                    Tfilter* filter_data, const Tensor& filter_tensor,
1093                    MklDnnData<Tfilter>& filter, const memory::desc& filter_md,
1094                    const MklDnnShape& filter_mkl_shape) TF_LOCKS_EXCLUDED(mu_) {
1095     mutex_lock lock(mu_);
1096     const Tensor& cached_filter_data_tensor =
1097         *cached_filter_data_ptensor_.AccessTensor(context);
1098 
1099     // If filter is already cached, there's nothing to do.
1100     if (cached_filter_data_tensor.NumElements() > 0) {
1101       return;
1102     }
1103 
1104     // Otherwise, cache filter
1105     filter.SetUsrMem(filter_md, &filter_tensor);
1106     filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(),
1107                                this->cpu_engine_, context);
1108     filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
1109 
1110     Tensor* filter_tensor_ptr = nullptr;
1111     AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr,
1112                              &filter_mkl_shape);
1113     void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
1114     size_t cached_filter_data_size = filter.GetOpMem().get_desc().get_size();
1115     memcpy(cached_filter_data, filter_data, cached_filter_data_size);
1116   }
1117 
AreMemoryDescriptorsEqual(const memory::desc & filter_md,const Tensor & cached_filter_md)1118   bool AreMemoryDescriptorsEqual(const memory::desc& filter_md,
1119                                  const Tensor& cached_filter_md) {
1120     auto filter_md_data = filter_md.data;
1121     const char* filter_data = reinterpret_cast<const char*>(&filter_md_data);
1122 
1123     auto cached_filter_md_data = cached_filter_md.scalar<int64>()();
1124     const char* cached_filter_data =
1125         reinterpret_cast<const char*>(&cached_filter_md_data);
1126 
1127     for (size_t i = 0; i < sizeof(filter_md_data); ++i) {
1128       if (*filter_data++ != *cached_filter_data++) {
1129         return false;
1130       }
1131     }
1132     return true;
1133   }
1134 
GetCachedFilter(OpKernelContext * context,const memory::desc & filter_md)1135   Tfilter* GetCachedFilter(OpKernelContext* context,
1136                            const memory::desc& filter_md)
1137       TF_LOCKS_EXCLUDED(mu_) {
1138     tf_shared_lock lock(mu_);
1139     const Tensor& cached_filter_data =
1140         *cached_filter_data_ptensor_.AccessTensor(context);
1141     const Tensor& cached_filter_md =
1142         *cached_filter_md_ptensor_.AccessTensor(context);
1143 
1144     // Check if the memory descriptor of the cached weights is the same as
1145     // filter_md. If so, we can use the cached weights; otherwise
1146     // return nullptr.
1147     if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
1148       return static_cast<Tfilter*>(
1149           const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
1150     }
1151     return nullptr;
1152   }
1153 };
1154 
1155 // Base class for fused convolution forward operations
1156 template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
1157           typename Toutput, typename Ttemp_output, typename Tpadding,
1158           bool pad_enabled, bool native_format>
1159 class MklFusedConvOp
1160     : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
1161                        Tpadding, false, false, false, native_format> {
1162  public:
MklFusedConvOp(OpKernelConstruction * context)1163   explicit MklFusedConvOp(OpKernelConstruction* context)
1164       : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
1165                   Tpadding, false, false, false, native_format>(context) {
1166     // Since we came here through the registration of _MklFusedConv2D, get
1167     // all information from 'fused_ops' and 'num_args'
1168     std::vector<string> fused_ops;
1169     OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
1170 
1171     int num_args;
1172     OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
1173     OP_REQUIRES(context, !fused_ops.empty(),
1174                 errors::InvalidArgument(
1175                     "Fused Conv2D must have at least one fused op."));
1176 
1177     if (fused_ops == std::vector<string>{"BiasAdd"}) {
1178       this->set_fuse_biasadd(true);
1179       OP_REQUIRES(context, num_args == 1,
1180                   errors::InvalidArgument(
1181                       "Fused Conv2D must have one extra argument: bias."));
1182     } else if (fused_ops == std::vector<string>{"Relu"}) {
1183       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
1184     } else if (fused_ops == std::vector<string>{"Relu6"}) {
1185       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
1186                                 6.0);
1187     } else if (fused_ops == std::vector<string>{"Elu"}) {
1188       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
1189     } else if (fused_ops == std::vector<string>{"LeakyRelu"}) {
1190       float leakyrelu_alpha;
1191       OP_REQUIRES_OK(context,
1192                      context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
1193       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
1194                                 leakyrelu_alpha);
1195     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
1196       this->set_fuse_biasadd(true);
1197       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
1198       OP_REQUIRES(context, num_args == 1,
1199                   errors::InvalidArgument(
1200                       "Fused Conv2D must have one extra argument: bias."));
1201     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
1202       this->set_fuse_biasadd(true);
1203       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
1204                                 6.0);
1205       OP_REQUIRES(context, num_args == 1,
1206                   errors::InvalidArgument(
1207                       "Fused Conv2D must have one extra argument: bias."));
1208     } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
1209       this->set_fuse_biasadd(true);
1210       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
1211       OP_REQUIRES(context, num_args == 1,
1212                   errors::InvalidArgument(
1213                       "Fused Conv2D must have one extra argument: bias."));
1214     } else if (fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"}) {
1215       this->set_fuse_biasadd(true);
1216       float leakyrelu_alpha;
1217       OP_REQUIRES_OK(context,
1218                      context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
1219       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
1220                                 leakyrelu_alpha);
1221       OP_REQUIRES(context, num_args == 1,
1222                   errors::InvalidArgument(
1223                       "Fused Conv2D must have one extra argument: bias."));
1224     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add"}) {
1225       this->set_fuse_biasadd(true);
1226       this->set_fuse_add(true);
1227       OP_REQUIRES(
1228           context, num_args == 2,
1229           errors::InvalidArgument(
1230               "Fused Conv2D must have two extra arguments: bias and add."));
1231     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) {
1232       this->set_fuse_biasadd(true);
1233       this->set_fuse_add(true);
1234       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
1235       OP_REQUIRES(
1236           context, num_args == 2,
1237           errors::InvalidArgument(
1238               "Fused Conv2D must have two extra arguments: bias and add."));
1239     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) {
1240       this->set_fuse_biasadd(true);
1241       this->set_fuse_add(true);
1242       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
1243                                 6.0);
1244       OP_REQUIRES(
1245           context, num_args == 2,
1246           errors::InvalidArgument(
1247               "Fused Conv2D must have two extra arguments: bias and add."));
1248     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) {
1249       this->set_fuse_biasadd(true);
1250       this->set_fuse_add(true);
1251       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
1252       OP_REQUIRES(
1253           context, num_args == 2,
1254           errors::InvalidArgument(
1255               "Fused Conv2D must have two extra arguments: bias and add."));
1256     } else if (fused_ops ==
1257                std::vector<string>{"BiasAdd", "Add", "LeakyRelu"}) {
1258       this->set_fuse_biasadd(true);
1259       this->set_fuse_add(true);
1260       float leakyrelu_alpha;
1261       OP_REQUIRES_OK(context,
1262                      context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
1263       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
1264                                 leakyrelu_alpha);
1265       OP_REQUIRES(
1266           context, num_args == 2,
1267           errors::InvalidArgument(
1268               "Fused Conv2D must have two extra arguments: bias and add."));
1269     } else {
1270       OP_REQUIRES(context, false,
1271                   errors::Unimplemented("Fusion is not implemented: [",
1272                                         absl::StrJoin(fused_ops, ","), "]"));
1273     }
1274 
1275     if (pad_enabled) {
1276       this->set_fuse_pad(true);
1277     }
1278   }
1279 
~MklFusedConvOp()1280   virtual ~MklFusedConvOp() {}
1281 };
1282 
1283 template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
1284           typename Toutput, typename Ttemp_output, typename Tpadding,
1285           bool pad_enabled, bool bias_enabled, bool is_depthwise,
1286           bool native_format>
1287 class MklFusedDepthwiseConvOp
1288     : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
1289                        Tpadding, bias_enabled, false, is_depthwise,
1290                        native_format> {
1291  public:
MklFusedDepthwiseConvOp(OpKernelConstruction * context)1292   explicit MklFusedDepthwiseConvOp(OpKernelConstruction* context)
1293       : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
1294                   Tpadding, bias_enabled, false, is_depthwise, native_format>(
1295             context) {
1296     // Since we came here through the registration of
1297     // _MklFusedDepthwiseConv2dNative, get all
1298     // information from 'fused_ops' and 'num_args'
1299     std::vector<string> fused_ops;
1300     OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
1301 
1302     int num_args;
1303     OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
1304     OP_REQUIRES(context, !fused_ops.empty(),
1305                 errors::InvalidArgument(
1306                     "Fused DepthwiseConv2D must have at least one fused op."));
1307 
1308     if (fused_ops == std::vector<string>{"BiasAdd"}) {
1309       this->set_fuse_biasadd(true);
1310     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
1311       this->set_fuse_biasadd(true);
1312       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
1313     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
1314       this->set_fuse_biasadd(true);
1315       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
1316                                 6.0);
1317     } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
1318       this->set_fuse_biasadd(true);
1319       this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
1320     } else {
1321       OP_REQUIRES(context, false,
1322                   errors::Unimplemented("Fusion is not implemented: [",
1323                                         absl::StrJoin(fused_ops, ","), "]"));
1324     }
1325 
1326     OP_REQUIRES(
1327         context, num_args == 1,
1328         errors::InvalidArgument(
1329             "Fused DepthwiseConv2D must have one extra argument: bias."));
1330 
1331     if (pad_enabled) {
1332       this->set_fuse_pad(true);
1333     }
1334   }
1335 
~MklFusedDepthwiseConvOp()1336   virtual ~MklFusedDepthwiseConvOp() {}
1337 };
1338 
1339 // We create new class for each version of Quantized Convolution and inherit
1340 // from the FP32 version of the base class
1341 template <typename Device, typename Tinput, typename Tbias, typename Toutput,
1342           typename Ttemp_output, bool bias_enabled, bool is_depthwise>
1343 class MklQuantizedConv2DOp
1344     : public MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output,
1345                        int32, bias_enabled, false, is_depthwise, false> {
1346  public:
~MklQuantizedConv2DOp()1347   virtual ~MklQuantizedConv2DOp() {
1348     if (this->input_bias_ != nullptr) {
1349       delete this->input_bias_;
1350       input_bias_ = nullptr;
1351     }
1352 
1353     if (this->scaled_bias_ != nullptr) {
1354       delete this->scaled_bias_;
1355       scaled_bias_ = nullptr;
1356     }
1357   }
1358 
MklQuantizedConv2DOp(OpKernelConstruction * context)1359   explicit MklQuantizedConv2DOp(OpKernelConstruction* context)
1360       : MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
1361                   bias_enabled, false, is_depthwise, false>(context) {
1362     bool is_filter_const;
1363     OP_REQUIRES_OK(context,
1364                    context->GetAttr("is_filter_const", &is_filter_const));
1365 
1366     if (bias_enabled) {
1367       OP_REQUIRES_OK(context,
1368                      context->GetAttr("is_bias_const", &is_bias_const_));
1369     }
1370 
1371     OP_REQUIRES(context, is_filter_const,
1372                 errors::InvalidArgument("Filter must be a constant"));
1373   }
1374 
Compute(OpKernelContext * context)1375   void Compute(OpKernelContext* context) override {
1376     // Compute int32 output tensor
1377     MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
1378               bias_enabled, false, is_depthwise, false>::Compute(context);
1379 
1380     // Compute additional outputs: min/max scalars.
1381     int bias_index_offset;
1382     bias_index_offset = bias_enabled ? 1 : 0;
1383 
1384     const float min_input =
1385         context->input(2 + bias_index_offset).flat<float>()(0);
1386     const float max_input =
1387         context->input(3 + bias_index_offset).flat<float>()(0);
1388 
1389     MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
1390     output_min_mkl_shape.SetMklTensor(false);
1391     output_max_mkl_shape.SetMklTensor(false);
1392 
1393     Tensor* output_min = nullptr;
1394     Tensor* output_max = nullptr;
1395     if (std::is_same<Toutput, quint8>::value ||
1396         std::is_same<Toutput, qint8>::value) {
1397       AllocateOutputSetMklShape(context, 1, &output_min, {},
1398                                 output_min_mkl_shape);
1399       AllocateOutputSetMklShape(context, 2, &output_max, {},
1400                                 output_max_mkl_shape);
1401       // This is the case the convolution and requantization are fused.
1402       output_min->flat<float>()(0) =
1403           context->input(6 + bias_index_offset).flat<float>()(0);
1404       output_max->flat<float>()(0) =
1405           context->input(7 + bias_index_offset).flat<float>()(0);
1406     } else {
1407       const Tensor& min_filter = context->input(4 + bias_index_offset);
1408       const Tensor& max_filter = context->input(5 + bias_index_offset);
1409       if (min_filter.dims() == 0) {
1410         float min_output_value;
1411         float max_output_value;
1412         MklQuantizationRangeForMultiplication<Tinput, qint8, qint32>(
1413             min_input, max_input, min_filter.flat<float>()(0),
1414             max_filter.flat<float>()(0), &min_output_value, &max_output_value);
1415         AllocateOutputSetMklShape(context, 1, &output_min, {},
1416                                   output_min_mkl_shape);
1417         AllocateOutputSetMklShape(context, 2, &output_max, {},
1418                                   output_max_mkl_shape);
1419         output_min->flat<float>()(0) = min_output_value;
1420         output_max->flat<float>()(0) = max_output_value;
1421       } else {
1422         size_t depth = min_filter.NumElements();
1423         AllocateOutputSetMklShape(context, 1, &output_min,
1424                                   {static_cast<ptrdiff_t>(depth)},
1425                                   output_min_mkl_shape);
1426         AllocateOutputSetMklShape(context, 2, &output_max,
1427                                   {static_cast<ptrdiff_t>(depth)},
1428                                   output_max_mkl_shape);
1429         MklQuantizationRangeForMultiplication<Tinput, qint8, qint32>(
1430             min_input, max_input, min_filter, max_filter, &output_min,
1431             &output_max);
1432       }
1433     }
1434   }
1435 
1436  protected:
ExtendConvFwdParams(OpKernelContext * context,MklConvFwdParams & params)1437   void ExtendConvFwdParams(OpKernelContext* context,
1438                            MklConvFwdParams& params) override {
1439     MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
1440               bias_enabled, false, is_depthwise,
1441               false>::ExtendConvFwdParams(context, params);
1442 
1443     // When the output type is quint8, the output data id requantized
1444     // into quint8. A post_op "output_scale" is added to do the conversion.
1445     if (std::is_same<Toutput, quint8>::value ||
1446         std::is_same<Toutput, qint8>::value) {
1447       int bias_index_offset;
1448       bias_index_offset = bias_enabled ? 1 : 0;
1449 
1450       const float min_input =
1451           context->input(2 + bias_index_offset).flat<float>()(0);
1452       const float max_input =
1453           context->input(3 + bias_index_offset).flat<float>()(0);
1454       const Tensor& min_filter_vector = context->input(4 + bias_index_offset);
1455       const Tensor& max_filter_vector = context->input(5 + bias_index_offset);
1456 
1457       // min_freezed_output and max_freezed_output are the actual range
1458       // for the output.
1459       const float min_freezed_output =
1460           context->input(6 + bias_index_offset).flat<float>()(0);
1461       const float max_freezed_output =
1462           context->input(7 + bias_index_offset).flat<float>()(0);
1463 
1464       float int_output_limit =
1465           std::is_same<Toutput, quint8>::value ? 255.0f : 127.0f;
1466       size_t depth = min_filter_vector.NumElements();
1467       const float* min_filter = min_filter_vector.flat<float>().data();
1468       const float* max_filter = max_filter_vector.flat<float>().data();
1469       std::vector<float> scales(depth);
1470       float float_input_range =
1471           std::max(std::abs(min_input), std::abs(max_input));
1472       float float_output_range =
1473           std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
1474       const float int_const_scale_limit =
1475           (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
1476       for (size_t i = 0; i < depth; ++i) {
1477         // For simplicity and symmetry, we set filter range to be outer
1478         // bounds of min_filter and max_filter.
1479         float float_filter_range =
1480             std::max(std::abs(min_filter[i]), std::abs(max_filter[i]));
1481         // To understand the scaling, please see mkl_requantize_ops_test.
1482         scales[i] = int_output_limit * float_input_range * float_filter_range /
1483                     (int_const_scale_limit * float_output_range);
1484       }
1485       // we are creating a partial key here to use with primitive key caching to
1486       // improve key creation performance. Instead of using actual values we are
1487       // using the pointers for min/max_filter_vector, and this works since the
1488       // filter vector here is a constant.
1489       FactoryKeyCreator param_key;
1490       param_key.AddAsKey<float>(min_input);
1491       param_key.AddAsKey<float>(max_input);
1492       param_key.AddAsKey<float>(min_freezed_output);
1493       param_key.AddAsKey<float>(max_freezed_output);
1494       param_key.AddAsKey<const float*>(min_filter);
1495       param_key.AddAsKey<const float*>(max_filter);
1496       params.post_op_params.push_back({"output_scale", mkldnn::algorithm::undef,
1497                                        scales, param_key.GetKey()});
1498     }
1499   }
1500 
GetBiasHandle(OpKernelContext * context,std::shared_ptr<ConvFwdPd> & conv_fwd_pd,const Tensor & bias_tensor)1501   Tbias* GetBiasHandle(OpKernelContext* context,
1502                        std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
1503                        const Tensor& bias_tensor) override {
1504     if (!bias_enabled) {
1505       return nullptr;
1506     }
1507     if (std::is_same<Tbias, qint32>::value) {
1508       return static_cast<Tbias*>(
1509           const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
1510     }
1511     int bias_index_offset;
1512     bias_index_offset = bias_enabled ? 1 : 0;
1513 
1514     const float min_input =
1515         context->input(2 + bias_index_offset).flat<float>()(0);
1516     const float max_input =
1517         context->input(3 + bias_index_offset).flat<float>()(0);
1518     const Tensor& min_filter_vector = context->input(4 + bias_index_offset);
1519     const Tensor& max_filter_vector = context->input(5 + bias_index_offset);
1520     const float* min_filter = min_filter_vector.flat<float>().data();
1521     const float* max_filter = max_filter_vector.flat<float>().data();
1522 
1523     const float int_const_scale_limit =
1524         (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
1525     // Re-scale bias if either of following 2 conditions are met:
1526     // 1. Bias is not const;
1527     // 2. Bias is const, but bias cache is empty (first iteration).
1528 
1529     size_t depth = min_filter_vector.NumElements();
1530     bool scales_are_valid = (depth == scales_.size());
1531     scales_.resize(depth);
1532     for (size_t i = 0; i < depth; ++i) {
1533       float tmp_scale =
1534           int_const_scale_limit /
1535           (std::max(std::abs(max_input), std::abs(min_input)) *
1536            std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
1537       if (scales_are_valid && std::abs(tmp_scale - scales_[i]) > 1e-6) {
1538         scales_are_valid = false;
1539       }
1540       scales_[i] = tmp_scale;
1541     }
1542     if (!is_bias_const_ || IsBiasCacheEmpty(context) || !scales_are_valid) {
1543       mkldnn::primitive_attr bias_attr;
1544       if (depth == 1) {
1545         bias_attr.set_output_scales(0, scales_);
1546       } else {
1547         bias_attr.set_output_scales(1, scales_);
1548       }
1549 
1550       auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())},
1551                                   MklDnnType<Tbias>(), memory::format_tag::x);
1552       void* bias_buf = static_cast<void*>(
1553           const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
1554       if (!input_bias_) {
1555         input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf);
1556       } else {
1557         input_bias_->set_data_handle(bias_buf);
1558       }
1559 
1560       if (!scaled_bias_buf_)
1561         AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
1562                               conv_fwd_pd->bias_desc(), &scaled_bias_buf_);
1563       if (!scaled_bias_) {
1564         scaled_bias_ = new memory(bias_md, this->cpu_engine_, scaled_bias_buf_);
1565       } else {
1566         scaled_bias_->set_data_handle(scaled_bias_buf_);
1567       }
1568       auto reorder_desc =
1569           ReorderPd(this->cpu_engine_, input_bias_->get_desc(),
1570                     this->cpu_engine_, scaled_bias_->get_desc(), bias_attr);
1571       CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
1572                               this->cpu_engine_, context);
1573 
1574       Tbias* bias_data =
1575           reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
1576       if (is_bias_const_)
1577         CacheBias(context, conv_fwd_pd, bias_data, scaled_bias_);
1578 
1579       return bias_data;
1580     }
1581     return GetCachedBias(context);
1582   }
1583 
1584   bool is_bias_const_;
1585   PersistentTensor cached_bias_data_ptensor_ TF_GUARDED_BY(bias_cache_mu_);
1586 
1587   memory* input_bias_ = nullptr;
1588   memory* scaled_bias_ = nullptr;
1589 
1590   Tensor scaled_bias_tensor_;
1591   void* scaled_bias_buf_ = nullptr;
1592 
1593  private:
1594   std::vector<float> scales_;
1595   mutex bias_cache_mu_;
1596   // Allocate persistent tensors for cached bias data and
1597   // cached bias memory descriptor (data format)
AllocatePersistentTensor(OpKernelContext * context,const ConvFwdPd & conv_prim_desc,Tensor ** bias_tensor)1598   void AllocatePersistentTensor(OpKernelContext* context,
1599                                 const ConvFwdPd& conv_prim_desc,
1600                                 Tensor** bias_tensor) {
1601     DCHECK(bias_tensor);
1602     TensorShape bias_tf_shape;
1603     bias_tf_shape.AddDim(
1604         (conv_prim_desc.bias_desc().get_size() / sizeof(Tbias)));
1605     OP_REQUIRES_OK(context, context->allocate_persistent(
1606                                 DataTypeToEnum<Tbias>::value, bias_tf_shape,
1607                                 &cached_bias_data_ptensor_, bias_tensor));
1608   }
1609 
1610   // TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
1611   // be acquired before entering the function, since it is acquired
1612   // inside the function.
IsBiasCacheEmpty(OpKernelContext * context)1613   inline bool IsBiasCacheEmpty(OpKernelContext* context)
1614       TF_LOCKS_EXCLUDED(bias_cache_mu_) {
1615     tf_shared_lock lock(bias_cache_mu_);
1616     return (cached_bias_data_ptensor_.NumElements() == 0);
1617   }
1618 
1619   // Cache the converted bias in a persistent tensor.
1620   // Only one thread can execute this method at any given time.
CacheBias(OpKernelContext * context,const std::shared_ptr<ConvFwdPd> & conv_fwd_pd,Tbias * bias_data,const memory * scaled_bias)1621   void CacheBias(OpKernelContext* context,
1622                  const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
1623                  Tbias* bias_data, const memory* scaled_bias)
1624       TF_LOCKS_EXCLUDED(bias_cache_mu_) {
1625     mutex_lock lock(bias_cache_mu_);
1626 
1627     // If bias is already cached, there's nothing to do.
1628     if (cached_bias_data_ptensor_.NumElements() > 0) {
1629       return;
1630     }
1631 
1632     // Otherwise, cache bias
1633     Tensor* bias_tensor_ptr = nullptr;
1634     AllocatePersistentTensor(context, *conv_fwd_pd, &bias_tensor_ptr);
1635     void* cached_bias_data = const_cast<void*>(
1636         static_cast<const void*>(bias_tensor_ptr->flat<Tbias>().data()));
1637     size_t cached_bias_data_size = scaled_bias->get_desc().get_size();
1638     memcpy(cached_bias_data, bias_data, cached_bias_data_size);
1639   }
1640 
GetCachedBias(OpKernelContext * context)1641   Tbias* GetCachedBias(OpKernelContext* context)
1642       TF_LOCKS_EXCLUDED(bias_cache_mu_) {
1643     tf_shared_lock lock(bias_cache_mu_);
1644     const Tensor& cached_bias_data =
1645         *cached_bias_data_ptensor_.AccessTensor(context);
1646 
1647     return static_cast<Tbias*>(
1648         const_cast<Tbias*>(cached_bias_data.flat<Tbias>().data()));
1649   }
1650 };
1651 
1652 template <typename Device, typename Tinput, typename Tbias, typename Toutput,
1653           typename Ttemp_output, bool bias_enabled, bool is_depthwise>
1654 class MklQuantizedConv2DReluOp
1655     : public MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
1656                                   bias_enabled, is_depthwise> {
1657  public:
~MklQuantizedConv2DReluOp()1658   virtual ~MklQuantizedConv2DReluOp() {}
1659 
MklQuantizedConv2DReluOp(OpKernelConstruction * context)1660   explicit MklQuantizedConv2DReluOp(OpKernelConstruction* context)
1661       : MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
1662                              bias_enabled, is_depthwise>(context) {}
1663 
1664  protected:
ExtendConvFwdParams(OpKernelContext * context,MklConvFwdParams & params)1665   void ExtendConvFwdParams(OpKernelContext* context,
1666                            MklConvFwdParams& params) override {
1667     MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
1668                          bias_enabled,
1669                          is_depthwise>::ExtendConvFwdParams(context, params);
1670 
1671     params.post_op_params.push_back(
1672         {"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
1673   }
1674 };
1675 
1676 template <typename Device, typename Tinput, typename Tbias, typename Toutput,
1677           typename Ttemp_output, bool bias_enabled, bool is_depthwise>
1678 class MklQuantizedConv2DSumReluOp
1679     : public MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
1680                                   bias_enabled, is_depthwise> {
1681  public:
~MklQuantizedConv2DSumReluOp()1682   virtual ~MklQuantizedConv2DSumReluOp() {}
1683 
MklQuantizedConv2DSumReluOp(OpKernelConstruction * context)1684   explicit MklQuantizedConv2DSumReluOp(OpKernelConstruction* context)
1685       : MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
1686                              bias_enabled, is_depthwise>(context) {}
1687 
1688  protected:
ExtendConvFwdParams(OpKernelContext * context,MklConvFwdParams & params)1689   void ExtendConvFwdParams(OpKernelContext* context,
1690                            MklConvFwdParams& params) override {
1691     MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
1692                          bias_enabled,
1693                          is_depthwise>::ExtendConvFwdParams(context, params);
1694     // Calculate the scale (beta in mkldnn api term) for sum
1695     if (std::is_same<Toutput, quint8>::value) {
1696       int summand_idx = context->num_inputs() / 2 - 1 - 2;
1697       DataType summand_type = this->input_type(summand_idx);
1698       bool summand_condition =
1699           (summand_type == DT_QINT8) || (summand_type == DT_QUINT8);
1700       CHECK((summand_condition));
1701       int bias_index_offset = bias_enabled ? 1 : 0;
1702       const float min_freezed_output =
1703           context->input(6 + bias_index_offset).flat<float>()(0);
1704       const float max_freezed_output =
1705           context->input(7 + bias_index_offset).flat<float>()(0);
1706       const float min_freezed_summand =
1707           context->input(9 + bias_index_offset).flat<float>()(0);
1708       const float max_freezed_summand =
1709           context->input(10 + bias_index_offset).flat<float>()(0);
1710 
1711       float scale_output =
1712           std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
1713       float scale_summand = std::max(std::abs(min_freezed_summand),
1714                                      std::abs(max_freezed_summand));
1715       // if summand_type is also DT_QUINT8 as the scale_output,
1716       // the scaling factor of 255.0f cancels each other and thus is avoided.
1717       // If it is not then  it is DT_INT8 and is scaled appropriately.
1718       if (summand_type == DT_QUINT8) {
1719         params.post_op_params.push_back({"sum",
1720                                          mkldnn::algorithm::undef,
1721                                          {scale_summand / scale_output},
1722                                          ""});
1723       } else {
1724         params.post_op_params.push_back(
1725             {"sum",
1726              mkldnn::algorithm::undef,
1727              {255.0f * scale_summand / (scale_output * 127.0f)},
1728              ""});
1729       }
1730     } else {
1731       params.post_op_params.push_back(
1732           {"sum", mkldnn::algorithm::undef, {1.0}, ""});
1733     }
1734     params.post_op_params.push_back(
1735         {"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
1736   }
1737 
AllocateOutputTensor(OpKernelContext * context,const ConvFwdPd & conv_prim_desc,const memory::dims & output_dims_mkl_order,MklTensorFormat output_tf_format,MklDnnShape * output_mkl_shape,Tensor ** output_tensor)1738   void AllocateOutputTensor(OpKernelContext* context,
1739                             const ConvFwdPd& conv_prim_desc,
1740                             const memory::dims& output_dims_mkl_order,
1741                             MklTensorFormat output_tf_format,
1742                             MklDnnShape* output_mkl_shape,
1743                             Tensor** output_tensor) override {
1744     int summand_idx = context->num_inputs() / 2 - 1;
1745     if (std::is_same<Toutput, quint8>::value) {
1746       summand_idx -= 2;
1747       DataType summand_type = this->input_type(summand_idx);
1748       bool summand_condition =
1749           (summand_type == DT_QINT8) || (summand_type == DT_QUINT8);
1750       CHECK((summand_condition));
1751       Tensor& summand = const_cast<Tensor&>(MklGetInput(context, summand_idx));
1752       MklDnnShape summand_mkl_shape;
1753       GetMklShape(context, summand_idx, &summand_mkl_shape);
1754       auto dst_md = summand_mkl_shape.GetMklLayout();
1755 
1756       // TODO(intel-tf): Handle both non-MKL and MKL tensors
1757       if (summand_type == DT_QINT8) {
1758         OP_REQUIRES_OK(
1759             context, summand.BitcastFrom(summand, DT_QUINT8, summand.shape()));
1760         dst_md.data.data_type =
1761             static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
1762         summand_mkl_shape.SetMklLayout(&dst_md);
1763         summand_mkl_shape.SetElemType(MklDnnType<Toutput>());
1764       }
1765       // TODO(intel-tf): Support cases when summand cannot be forwarded.
1766       OP_REQUIRES(
1767           context,
1768           ForwardMklTensorInToOutWithMklShape(
1769               context, summand_idx, 0, output_tensor, summand_mkl_shape, false),
1770           errors::InvalidArgument(
1771               "Summand cannot be forwarded in the current fusion."));
1772       return;
1773     }
1774     MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
1775               bias_enabled, false, false,
1776               false>::AllocateOutputTensor(context, conv_prim_desc,
1777                                            output_dims_mkl_order,
1778                                            output_tf_format, output_mkl_shape,
1779                                            output_tensor);
1780     const Tensor& summand = MklGetInput(context, summand_idx);
1781     if (summand.dtype() != DT_FLOAT)
1782       TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
1783                          "Current fusion requires summand to be float"));
1784     MklDnnShape summand_mkl_shape;
1785     GetMklShape(context, summand_idx, &summand_mkl_shape);
1786     // We need to compute scale for the summand
1787     int bias_index_offset = bias_enabled ? 1 : 0;
1788     const float min_input =
1789         context->input(2 + bias_index_offset).flat<float>()(0);
1790     const float max_input =
1791         context->input(3 + bias_index_offset).flat<float>()(0);
1792     const Tensor& min_filter_vector = context->input(4 + bias_index_offset);
1793     const Tensor& max_filter_vector = context->input(5 + bias_index_offset);
1794     const float* min_filter = min_filter_vector.flat<float>().data();
1795     const float* max_filter = max_filter_vector.flat<float>().data();
1796 
1797     const float int_const_scale_limit =
1798         (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
1799     size_t depth = min_filter_vector.NumElements();
1800     std::vector<float> scales(depth);
1801     for (size_t i = 0; i < depth; ++i) {
1802       // TODO(nammbash): scale factors for UINT8(inputs) & INT8(weights) are
1803       // done regularly. A Cleaner design to address all mapping in one
1804       // function needs to be implemented in future which also supports other
1805       // quantized type mapping in future.
1806       scales[i] = int_const_scale_limit /
1807                   (std::max(std::abs(max_input), std::abs(min_input)) *
1808                    std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
1809     }
1810     mkldnn::primitive_attr reorder_attr;
1811     if (depth == 1) {
1812       reorder_attr.set_output_scales(0, scales);
1813     } else {
1814       reorder_attr.set_output_scales(2, scales);
1815     }
1816     auto summand_md =
1817         summand_mkl_shape.IsMklTensor()
1818             ? summand_mkl_shape.GetMklLayout()
1819             : memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
1820                            memory::format_tag::nhwc);
1821     void* summand_buf =
1822         static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
1823     void* dst_buf =
1824         static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
1825     summand_.reset(new memory(summand_md, this->cpu_engine_, summand_buf));
1826     dst_.reset(
1827         new memory(conv_prim_desc.dst_desc(), this->cpu_engine_, dst_buf));
1828     auto reorder_desc =
1829         ReorderPd(this->cpu_engine_, summand_md, this->cpu_engine_,
1830                   conv_prim_desc.dst_desc(), reorder_attr);
1831     CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
1832                             context);
1833   }
1834 
1835   std::shared_ptr<mkldnn::memory> summand_;
1836   std::shared_ptr<mkldnn::memory> dst_;
1837 };
1838 
1839 // INT8 kernel registration
1840 // Register NoOp kernel for QuantizedConv2D for qint8 filter
1841 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2D")
1842                             .Device(DEVICE_CPU)
1843                             .TypeConstraint<quint8>("Tinput")
1844                             .TypeConstraint<qint8>("Tfilter")
1845                             .TypeConstraint<qint32>("out_type"),
1846                         NoOp);
1847 
1848 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndRequantize")
1849                             .Device(DEVICE_CPU)
1850                             .TypeConstraint<quint8>("Tinput")
1851                             .TypeConstraint<qint8>("Tfilter")
1852                             .TypeConstraint<qint8>("out_type"),
1853                         NoOp);
1854 
1855 // Register NoOp kernel for QuantizedConv2DPerChannel.
1856 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DPerChannel")
1857                             .Device(DEVICE_CPU)
1858                             .TypeConstraint<quint8>("Tinput")
1859                             .TypeConstraint<qint8>("Tfilter")
1860                             .TypeConstraint<qint32>("out_type"),
1861                         NoOp);
1862 // Register a templatized implementation of MklQuantizedConv2DPerChannel.
1863 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DPerChannel")
1864                             .Device(DEVICE_CPU)
1865                             .TypeConstraint<quint8>("Tinput")
1866                             .TypeConstraint<qint8>("Tfilter")
1867                             .TypeConstraint<qint32>("out_type")
1868                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
1869                         MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32,
1870                                              qint32, false, false>);
1871 
1872 // Register a templatized implementation of MklQuantizedConv2D.
1873 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2D")
1874                             .Device(DEVICE_CPU)
1875                             .TypeConstraint<quint8>("Tinput")
1876                             .TypeConstraint<qint8>("Tfilter")
1877                             .TypeConstraint<qint32>("out_type")
1878                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
1879                         MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32,
1880                                              qint32, false, false>);
1881 
1882 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2D")
1883                             .Device(DEVICE_CPU)
1884                             .TypeConstraint<qint8>("Tinput")
1885                             .TypeConstraint<qint8>("Tfilter")
1886                             .TypeConstraint<qint32>("out_type")
1887                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
1888                         MklQuantizedConv2DOp<CPUDevice, qint8, float, qint32,
1889                                              qint32, false, false>);
1890 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DAndRequantize")
1891                             .Device(DEVICE_CPU)
1892                             .TypeConstraint<quint8>("Tinput")
1893                             .TypeConstraint<qint8>("Tfilter")
1894                             .TypeConstraint<qint8>("out_type")
1895                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
1896                         MklQuantizedConv2DOp<CPUDevice, quint8, qint32, qint8,
1897                                              qint8, false, false>);
1898 
1899 // Register NoOp kernel for QuantizedConv2DWithBias to get a python interface.
1900 // This kernel will be replaced by an MKL kernel during graph
1901 // optimization pass.
1902 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBias")
1903                             .Device(DEVICE_CPU)
1904                             .TypeConstraint<quint8>("Tinput")
1905                             .TypeConstraint<qint8>("Tfilter")
1906                             .TypeConstraint<qint32>("out_type"),
1907                         NoOp);
1908 
1909 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRequantize")
1910                             .Device(DEVICE_CPU)
1911                             .TypeConstraint<quint8>("Tinput")
1912                             .TypeConstraint<qint8>("Tfilter")
1913                             .TypeConstraint<qint8>("out_type"),
1914                         NoOp);
1915 
1916 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBias")
1917                             .Device(DEVICE_CPU)
1918                             .TypeConstraint<qint8>("Tinput")
1919                             .TypeConstraint<qint8>("Tfilter")
1920                             .TypeConstraint<qint32>("out_type"),
1921                         NoOp);
1922 
1923 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRequantize")
1924                             .Device(DEVICE_CPU)
1925                             .TypeConstraint<qint8>("Tinput")
1926                             .TypeConstraint<qint8>("Tfilter")
1927                             .TypeConstraint<qint8>("out_type"),
1928                         NoOp);
1929 // Register a templatized implementation MklQuantizedConv2DWithBias.
1930 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBias")
1931                             .Device(DEVICE_CPU)
1932                             .TypeConstraint<quint8>("Tinput")
1933                             .TypeConstraint<qint8>("Tfilter")
1934                             .TypeConstraint<qint32>("out_type")
1935                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
1936                         MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32,
1937                                              qint32, true, false>);
1938 
1939 REGISTER_KERNEL_BUILDER(
1940     Name("_MklQuantizedConv2DWithBiasAndRequantize")
1941         .Device(DEVICE_CPU)
1942         .TypeConstraint<quint8>("Tinput")
1943         .TypeConstraint<qint8>("Tfilter")
1944         .TypeConstraint<qint32>("Tbias")
1945         .TypeConstraint<qint8>("out_type")
1946         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1947     MklQuantizedConv2DOp<CPUDevice, quint8, qint32, qint8, qint8, true, false>);
1948 
1949 REGISTER_KERNEL_BUILDER(
1950     Name("_MklQuantizedConv2DWithBiasAndRequantize")
1951         .Device(DEVICE_CPU)
1952         .TypeConstraint<quint8>("Tinput")
1953         .TypeConstraint<qint8>("Tfilter")
1954         .TypeConstraint<float>("Tbias")
1955         .TypeConstraint<qint8>("out_type")
1956         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1957     MklQuantizedConv2DOp<CPUDevice, quint8, float, qint8, qint8, true, false>);
1958 
1959 REGISTER_KERNEL_BUILDER(
1960     Name("_MklQuantizedConv2DWithBias")
1961         .Device(DEVICE_CPU)
1962         .TypeConstraint<qint8>("Tinput")
1963         .TypeConstraint<qint8>("Tfilter")
1964         .TypeConstraint<qint32>("out_type")
1965         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1966     MklQuantizedConv2DOp<CPUDevice, qint8, float, qint32, qint32, true, false>);
1967 
1968 REGISTER_KERNEL_BUILDER(
1969     Name("_MklQuantizedConv2DWithBiasAndRequantize")
1970         .Device(DEVICE_CPU)
1971         .TypeConstraint<qint8>("Tinput")
1972         .TypeConstraint<qint8>("Tfilter")
1973         .TypeConstraint<qint32>("Tbias")
1974         .TypeConstraint<qint8>("out_type")
1975         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1976     MklQuantizedConv2DOp<CPUDevice, qint8, qint32, qint8, qint8, true, false>);
1977 
1978 REGISTER_KERNEL_BUILDER(
1979     Name("_MklQuantizedConv2DWithBiasAndRequantize")
1980         .Device(DEVICE_CPU)
1981         .TypeConstraint<qint8>("Tinput")
1982         .TypeConstraint<qint8>("Tfilter")
1983         .TypeConstraint<float>("Tbias")
1984         .TypeConstraint<qint8>("out_type")
1985         .Label(mkl_op_registry::kMklQuantizedOpLabel),
1986     MklQuantizedConv2DOp<CPUDevice, qint8, float, qint8, qint8, true, false>);
1987 
1988 // Register NoOp kernel for QuantizedConv2DAndRelu to get a python interface.
1989 // This kernel will be replaced by an MKL kernel during graph-optimization pass.
1990 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndRelu")
1991                             .Device(DEVICE_CPU)
1992                             .TypeConstraint<quint8>("Tinput")
1993                             .TypeConstraint<qint8>("Tfilter")
1994                             .TypeConstraint<qint32>("out_type"),
1995                         NoOp);
1996 
1997 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndReluAndRequantize")
1998                             .Device(DEVICE_CPU)
1999                             .TypeConstraint<quint8>("Tinput")
2000                             .TypeConstraint<qint8>("Tfilter")
2001                             .TypeConstraint<quint8>("out_type"),
2002                         NoOp);
2003 
2004 // Register a templatized implementation of MklQuantizedConv2DAndRelu.
2005 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DAndRelu")
2006                             .Device(DEVICE_CPU)
2007                             .TypeConstraint<quint8>("Tinput")
2008                             .TypeConstraint<qint8>("Tfilter")
2009                             .TypeConstraint<qint32>("out_type")
2010                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
2011                         MklQuantizedConv2DReluOp<CPUDevice, quint8, float,
2012                                                  qint32, qint32, false, false>);
2013 
2014 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DAndReluAndRequantize")
2015                             .Device(DEVICE_CPU)
2016                             .TypeConstraint<quint8>("Tinput")
2017                             .TypeConstraint<qint8>("Tfilter")
2018                             .TypeConstraint<quint8>("out_type")
2019                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
2020                         MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32,
2021                                                  quint8, quint8, false, false>);
2022 
2023 // Register NoOp kernel for QuantizedConv2DWithBiasAndRelu to get a python
2024 // interface.
2025 // This kernel will be replaced by an MKL kernel during graph-optimization pass.
2026 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRelu")
2027                             .Device(DEVICE_CPU)
2028                             .TypeConstraint<quint8>("Tinput")
2029                             .TypeConstraint<qint8>("Tfilter")
2030                             .TypeConstraint<qint32>("out_type"),
2031                         NoOp);
2032 
2033 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRelu")
2034                             .Device(DEVICE_CPU)
2035                             .TypeConstraint<qint8>("Tinput")
2036                             .TypeConstraint<qint8>("Tfilter")
2037                             .TypeConstraint<qint32>("out_type"),
2038                         NoOp);
2039 
2040 // Register NoOp kernel for QuantizedConv2DWithBiasAndReluAndRequantize
2041 // to get a python interface.
2042 // This kernel will be replaced by an MKL kernel during graph-optimization pass.
2043 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndReluAndRequantize")
2044                             .Device(DEVICE_CPU)
2045                             .TypeConstraint<quint8>("Tinput")
2046                             .TypeConstraint<qint8>("Tfilter")
2047                             .TypeConstraint<quint8>("out_type"),
2048                         NoOp);
2049 
2050 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndReluAndRequantize")
2051                             .Device(DEVICE_CPU)
2052                             .TypeConstraint<qint8>("Tinput")
2053                             .TypeConstraint<qint8>("Tfilter")
2054                             .TypeConstraint<quint8>("out_type"),
2055                         NoOp);
2056 // Register a templatized implementation of MklQuantizedConv2DWithBiasAndRelu.
2057 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndRelu")
2058                             .Device(DEVICE_CPU)
2059                             .TypeConstraint<quint8>("Tinput")
2060                             .TypeConstraint<qint8>("Tfilter")
2061                             .TypeConstraint<qint32>("out_type")
2062                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
2063                         MklQuantizedConv2DReluOp<CPUDevice, quint8, float,
2064                                                  qint32, qint32, true, false>);
2065 
2066 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndRelu")
2067                             .Device(DEVICE_CPU)
2068                             .TypeConstraint<qint8>("Tinput")
2069                             .TypeConstraint<qint8>("Tfilter")
2070                             .TypeConstraint<qint32>("out_type")
2071                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
2072                         MklQuantizedConv2DReluOp<CPUDevice, qint8, float,
2073                                                  qint32, qint32, true, false>);
2074 // Register a templatized implementation of
2075 // MklQuantizedConv2DWithBiasAndReluAndRequantize.
2076 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
2077                             .Device(DEVICE_CPU)
2078                             .TypeConstraint<quint8>("Tinput")
2079                             .TypeConstraint<qint8>("Tfilter")
2080                             .TypeConstraint<float>("Tbias")
2081                             .TypeConstraint<quint8>("out_type")
2082                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
2083                         MklQuantizedConv2DReluOp<CPUDevice, quint8, float,
2084                                                  quint8, quint8, true, false>);
2085 
2086 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
2087                             .Device(DEVICE_CPU)
2088                             .TypeConstraint<quint8>("Tinput")
2089                             .TypeConstraint<qint8>("Tfilter")
2090                             .TypeConstraint<qint32>("Tbias")
2091                             .TypeConstraint<quint8>("out_type")
2092                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
2093                         MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32,
2094                                                  quint8, quint8, true, false>);
2095 
2096 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
2097                             .Device(DEVICE_CPU)
2098                             .TypeConstraint<qint8>("Tinput")
2099                             .TypeConstraint<qint8>("Tfilter")
2100                             .TypeConstraint<float>("Tbias")
2101                             .TypeConstraint<quint8>("out_type")
2102                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
2103                         MklQuantizedConv2DReluOp<CPUDevice, qint8, float,
2104                                                  quint8, quint8, true, false>);
2105 
2106 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
2107                             .Device(DEVICE_CPU)
2108                             .TypeConstraint<qint8>("Tinput")
2109                             .TypeConstraint<qint8>("Tfilter")
2110                             .TypeConstraint<qint32>("Tbias")
2111                             .TypeConstraint<quint8>("out_type")
2112                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
2113                         MklQuantizedConv2DReluOp<CPUDevice, qint8, qint32,
2114                                                  quint8, quint8, true, false>);
2115 
2116 // Register NoOp kernel for QuantizedConv2DWithBiasSumAndRelu to get a python
2117 // interface.
2118 // This kernel will be replaced by an MKL kernel during graph-optimization pass.
2119 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasSumAndRelu")
2120                             .Device(DEVICE_CPU)
2121                             .TypeConstraint<quint8>("Tinput")
2122                             .TypeConstraint<qint8>("Tfilter")
2123                             .TypeConstraint<qint32>("out_type"),
2124                         NoOp);
2125 
2126 REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasSumAndReluAndRequantize")
2127                             .Device(DEVICE_CPU)
2128                             .TypeConstraint<quint8>("Tinput")
2129                             .TypeConstraint<qint8>("Tfilter")
2130                             .TypeConstraint<quint8>("out_type"),
2131                         NoOp);
2132 
2133 REGISTER_KERNEL_BUILDER(
2134     Name("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
2135         .Device(DEVICE_CPU)
2136         .TypeConstraint<quint8>("Tinput")
2137         .TypeConstraint<qint8>("Tfilter")
2138         .TypeConstraint<quint8>("out_type"),
2139     NoOp);
2140 
2141 // Register a templatized implementation of
2142 // MklQuantizedConv2DWithBiasSumAndRelu.
2143 REGISTER_KERNEL_BUILDER(
2144     Name("_MklQuantizedConv2DWithBiasSumAndRelu")
2145         .Device(DEVICE_CPU)
2146         .TypeConstraint<quint8>("Tinput")
2147         .TypeConstraint<qint8>("Tfilter")
2148         .TypeConstraint<qint32>("out_type")
2149         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2150     MklQuantizedConv2DSumReluOp<CPUDevice, quint8, float, qint32, qint32, true,
2151                                 false>);
2152 
2153 REGISTER_KERNEL_BUILDER(
2154     Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
2155         .Device(DEVICE_CPU)
2156         .TypeConstraint<quint8>("Tinput")
2157         .TypeConstraint<qint8>("Tfilter")
2158         .TypeConstraint<qint32>("Tbias")
2159         .TypeConstraint<quint8>("out_type")
2160         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2161     MklQuantizedConv2DSumReluOp<CPUDevice, quint8, qint32, quint8, quint8, true,
2162                                 false>);
2163 
2164 REGISTER_KERNEL_BUILDER(
2165     Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
2166         .Device(DEVICE_CPU)
2167         .TypeConstraint<quint8>("Tinput")
2168         .TypeConstraint<qint8>("Tfilter")
2169         .TypeConstraint<qint32>("Tbias")
2170         .TypeConstraint<quint8>("out_type")
2171         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2172     MklQuantizedConv2DSumReluOp<CPUDevice, quint8, qint32, quint8, qint8, true,
2173                                 false>);
2174 
2175 REGISTER_KERNEL_BUILDER(
2176     Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
2177         .Device(DEVICE_CPU)
2178         .TypeConstraint<quint8>("Tinput")
2179         .TypeConstraint<qint8>("Tfilter")
2180         .TypeConstraint<float>("Tbias")
2181         .TypeConstraint<quint8>("out_type")
2182         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2183     MklQuantizedConv2DSumReluOp<CPUDevice, quint8, float, quint8, quint8, true,
2184                                 false>);
2185 
2186 REGISTER_KERNEL_BUILDER(
2187     Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
2188         .Device(DEVICE_CPU)
2189         .TypeConstraint<quint8>("Tinput")
2190         .TypeConstraint<qint8>("Tfilter")
2191         .TypeConstraint<float>("Tbias")
2192         .TypeConstraint<quint8>("out_type")
2193         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2194     MklQuantizedConv2DSumReluOp<CPUDevice, quint8, float, quint8, qint8, true,
2195                                 false>);
2196 
2197 // Register NoOp kernels for non-fused and fused versions of
2198 // QuantizedDepthwiseConv2D to get a Python interface. These kernels will be
2199 // replaced by MKL kernels during the graph-optimization pass.
2200 REGISTER_KERNEL_BUILDER(Name("QuantizedDepthwiseConv2D")
2201                             .Device(DEVICE_CPU)
2202                             .TypeConstraint<quint8>("Tinput")
2203                             .TypeConstraint<qint8>("Tfilter")
2204                             .TypeConstraint<qint32>("out_type"),
2205                         NoOp);
2206 
2207 REGISTER_KERNEL_BUILDER(Name("QuantizedDepthwiseConv2DWithBias")
2208                             .Device(DEVICE_CPU)
2209                             .TypeConstraint<quint8>("Tinput")
2210                             .TypeConstraint<qint8>("Tfilter")
2211                             .TypeConstraint<qint32>("out_type"),
2212                         NoOp);
2213 
2214 REGISTER_KERNEL_BUILDER(Name("QuantizedDepthwiseConv2DWithBiasAndRelu")
2215                             .Device(DEVICE_CPU)
2216                             .TypeConstraint<quint8>("Tinput")
2217                             .TypeConstraint<qint8>("Tfilter")
2218                             .TypeConstraint<qint32>("out_type"),
2219                         NoOp);
2220 
2221 REGISTER_KERNEL_BUILDER(
2222     Name("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")
2223         .Device(DEVICE_CPU)
2224         .TypeConstraint<quint8>("Tinput")
2225         .TypeConstraint<qint8>("Tfilter")
2226         .TypeConstraint<quint8>("out_type"),
2227     NoOp);
2228 
2229 REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
2230                             .Device(DEVICE_CPU)
2231                             .TypeConstraint<bfloat16>("T"),
2232                         NoOp);
2233 
2234 #define REGISTER_NO_OP_CPU_2D_DEPTHWISE(T)                    \
2235   REGISTER_KERNEL_BUILDER(Name("_FusedDepthwiseConv2dNative") \
2236                               .Device(DEVICE_CPU)             \
2237                               .TypeConstraint<T>("T"),        \
2238                           NoOp);
2239 
2240 TF_CALL_float(REGISTER_NO_OP_CPU_2D_DEPTHWISE);
2241 TF_CALL_bfloat16(REGISTER_NO_OP_CPU_2D_DEPTHWISE);
2242 
2243 // Register templatized MKL kernels for non-fused and fused-versions of
2244 // QuantizedDepthwiseConv2D.
2245 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedDepthwiseConv2D")
2246                             .Device(DEVICE_CPU)
2247                             .TypeConstraint<quint8>("Tinput")
2248                             .TypeConstraint<qint8>("Tfilter")
2249                             .TypeConstraint<qint32>("out_type")
2250                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
2251                         MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32,
2252                                              qint32, false, true>);
2253 
2254 REGISTER_KERNEL_BUILDER(
2255     Name("_MklQuantizedDepthwiseConv2DWithBias")
2256         .Device(DEVICE_CPU)
2257         .TypeConstraint<quint8>("Tinput")
2258         .TypeConstraint<qint8>("Tfilter")
2259         .TypeConstraint<qint32>("out_type")
2260         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2261     MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32, qint32, true, true>);
2262 
2263 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedDepthwiseConv2DWithBiasAndRelu")
2264                             .Device(DEVICE_CPU)
2265                             .TypeConstraint<quint8>("Tinput")
2266                             .TypeConstraint<qint8>("Tfilter")
2267                             .TypeConstraint<qint32>("out_type")
2268                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
2269                         MklQuantizedConv2DReluOp<CPUDevice, quint8, float,
2270                                                  qint32, qint32, true, true>);
2271 
2272 // Tbias -> float
2273 REGISTER_KERNEL_BUILDER(
2274     Name("_MklQuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")
2275         .Device(DEVICE_CPU)
2276         .TypeConstraint<quint8>("Tinput")
2277         .TypeConstraint<qint8>("Tfilter")
2278         .TypeConstraint<float>("Tbias")
2279         .TypeConstraint<quint8>("out_type")
2280         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2281     MklQuantizedConv2DReluOp<CPUDevice, quint8, float, quint8, quint8, true,
2282                              true>);
2283 
2284 // Tbias -> qint32
2285 REGISTER_KERNEL_BUILDER(
2286     Name("_MklQuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")
2287         .Device(DEVICE_CPU)
2288         .TypeConstraint<quint8>("Tinput")
2289         .TypeConstraint<qint8>("Tfilter")
2290         .TypeConstraint<qint32>("Tbias")
2291         .TypeConstraint<quint8>("out_type")
2292         .Label(mkl_op_registry::kMklQuantizedOpLabel),
2293     MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32, quint8, quint8, true,
2294                              true>);
2295 
2296 // Register 2D operations
2297 #define REGISTER_MKL_CPU_2D(T)                                                 \
2298   REGISTER_KERNEL_BUILDER(                                                     \
2299       Name("_MklConv2D")                                                       \
2300           .Device(DEVICE_CPU)                                                  \
2301           .TypeConstraint<T>("T")                                              \
2302           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2303       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \
2304   REGISTER_KERNEL_BUILDER(                                                     \
2305       Name("_MklConv2DWithBias")                                               \
2306           .Device(DEVICE_CPU)                                                  \
2307           .TypeConstraint<T>("T")                                              \
2308           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2309       MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, false>);  \
2310   REGISTER_KERNEL_BUILDER(                                                     \
2311       Name("__MklDummyConv2DWithBias")                                         \
2312           .Device(DEVICE_CPU)                                                  \
2313           .TypeConstraint<T>("T")                                              \
2314           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2315       MklDummyOp<CPUDevice, T>);                                               \
2316   REGISTER_KERNEL_BUILDER(                                                     \
2317       Name("_MklPadWithConv2D")                                                \
2318           .Device(DEVICE_CPU)                                                  \
2319           .TypeConstraint<T>("T")                                              \
2320           .TypeConstraint<int32>("Tpaddings")                                  \
2321           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2322       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, false>);  \
2323   REGISTER_KERNEL_BUILDER(                                                     \
2324       Name("_MklPadWithConv2D")                                                \
2325           .Device(DEVICE_CPU)                                                  \
2326           .TypeConstraint<T>("T")                                              \
2327           .TypeConstraint<int64>("Tpaddings")                                  \
2328           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2329       MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, false>);  \
2330   REGISTER_KERNEL_BUILDER(                                                     \
2331       Name("__MklDummyPadWithConv2D")                                          \
2332           .Device(DEVICE_CPU)                                                  \
2333           .TypeConstraint<T>("T")                                              \
2334           .TypeConstraint<int32>("Tpaddings")                                  \
2335           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2336       MklDummyOp<CPUDevice, T>);                                               \
2337   REGISTER_KERNEL_BUILDER(                                                     \
2338       Name("_MklNativeConv2D")                                                 \
2339           .Device(DEVICE_CPU)                                                  \
2340           .TypeConstraint<T>("T")                                              \
2341           .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
2342       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>);  \
2343   REGISTER_KERNEL_BUILDER(                                                     \
2344       Name("_MklNativeConv2DWithBias")                                         \
2345           .Device(DEVICE_CPU)                                                  \
2346           .TypeConstraint<T>("T")                                              \
2347           .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
2348       MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, true>);   \
2349   REGISTER_KERNEL_BUILDER(                                                     \
2350       Name("_MklNativePadWithConv2D")                                          \
2351           .Device(DEVICE_CPU)                                                  \
2352           .TypeConstraint<T>("T")                                              \
2353           .TypeConstraint<int32>("Tpaddings")                                  \
2354           .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
2355       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, true>);   \
2356   REGISTER_KERNEL_BUILDER(                                                     \
2357       Name("_MklNativePadWithConv2D")                                          \
2358           .Device(DEVICE_CPU)                                                  \
2359           .TypeConstraint<T>("T")                                              \
2360           .TypeConstraint<int64>("Tpaddings")                                  \
2361           .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
2362       MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, true>);
2363 
2364 TF_CALL_float(REGISTER_MKL_CPU_2D);
2365 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D);
2366 
2367 #define REGISTER_MKL_CPU_2D_DEPTHWISE(T)                                      \
2368   REGISTER_KERNEL_BUILDER(                                                    \
2369       Name("_MklDepthwiseConv2dNative")                                       \
2370           .Device(DEVICE_CPU)                                                 \
2371           .TypeConstraint<T>("T")                                             \
2372           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                \
2373       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, false>); \
2374   REGISTER_KERNEL_BUILDER(                                                    \
2375       Name("_MklFusedDepthwiseConv2dNative")                                  \
2376           .Device(DEVICE_CPU)                                                 \
2377           .TypeConstraint<T>("T")                                             \
2378           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                \
2379       MklFusedDepthwiseConvOp<CPUDevice, T, T, T, T, T, int32, false, true,   \
2380                               true, false>);                                  \
2381   REGISTER_KERNEL_BUILDER(                                                    \
2382       Name("_MklNativeFusedDepthwiseConv2dNative")                            \
2383           .Device(DEVICE_CPU)                                                 \
2384           .TypeConstraint<T>("T")                                             \
2385           .Label(mkl_op_registry::kMklNameChangeOpLabel),                     \
2386       MklFusedDepthwiseConvOp<CPUDevice, T, T, T, T, T, int32, false, true,   \
2387                               true, true>);                                   \
2388   REGISTER_KERNEL_BUILDER(                                                    \
2389       Name("_MklNativeDepthwiseConv2dNative")                                 \
2390           .Device(DEVICE_CPU)                                                 \
2391           .TypeConstraint<T>("T")                                             \
2392           .Label(mkl_op_registry::kMklNameChangeOpLabel),                     \
2393       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, true>);
2394 
2395 TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE);
2396 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_DEPTHWISE);
2397 
2398 // Note we are registering _MklFusedConv2D.
2399 // We check the fused_ops attributes to decide if bias is enabled or not.
2400 #define REGISTER_MKL_CPU_2D_FUSED(T)                                  \
2401   REGISTER_KERNEL_BUILDER(                                            \
2402       Name("_MklFusedConv2D")                                         \
2403           .Device(DEVICE_CPU)                                         \
2404           .TypeConstraint<T>("T")                                     \
2405           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
2406       MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, false, false>); \
2407   REGISTER_KERNEL_BUILDER(                                            \
2408       Name("_MklPadWithFusedConv2D")                                  \
2409           .Device(DEVICE_CPU)                                         \
2410           .TypeConstraint<int32>("Tpaddings")                         \
2411           .TypeConstraint<T>("T")                                     \
2412           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
2413       MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, true, false>);  \
2414   REGISTER_KERNEL_BUILDER(                                            \
2415       Name("_MklPadWithFusedConv2D")                                  \
2416           .Device(DEVICE_CPU)                                         \
2417           .TypeConstraint<T>("T")                                     \
2418           .TypeConstraint<int64>("Tpaddings")                         \
2419           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
2420       MklFusedConvOp<CPUDevice, T, T, T, T, T, int64, true, false>);  \
2421   REGISTER_KERNEL_BUILDER(                                            \
2422       Name("__MklDummyPadWithFusedConv2D")                            \
2423           .Device(DEVICE_CPU)                                         \
2424           .TypeConstraint<T>("T")                                     \
2425           .TypeConstraint<int32>("Tpaddings")                         \
2426           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
2427       MklDummyOp<CPUDevice, T>);                                      \
2428   REGISTER_KERNEL_BUILDER(                                            \
2429       Name("_MklNativeFusedConv2D")                                   \
2430           .Device(DEVICE_CPU)                                         \
2431           .TypeConstraint<T>("T")                                     \
2432           .Label(mkl_op_registry::kMklNameChangeOpLabel),             \
2433       MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, false, true>);  \
2434   REGISTER_KERNEL_BUILDER(                                            \
2435       Name("_MklNativePadWithFusedConv2D")                            \
2436           .Device(DEVICE_CPU)                                         \
2437           .TypeConstraint<int32>("Tpaddings")                         \
2438           .TypeConstraint<T>("T")                                     \
2439           .Label(mkl_op_registry::kMklNameChangeOpLabel),             \
2440       MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, true, true>);   \
2441   REGISTER_KERNEL_BUILDER(                                            \
2442       Name("_MklNativePadWithFusedConv2D")                            \
2443           .Device(DEVICE_CPU)                                         \
2444           .TypeConstraint<T>("T")                                     \
2445           .TypeConstraint<int64>("Tpaddings")                         \
2446           .Label(mkl_op_registry::kMklNameChangeOpLabel),             \
2447       MklFusedConvOp<CPUDevice, T, T, T, T, T, int64, true, true>);
2448 
2449 TF_CALL_float(REGISTER_MKL_CPU_2D_FUSED);
2450 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_FUSED);
2451 
2452 // Register 3D operations
2453 #define REGISTER_MKL_CPU_3D(T)                                                 \
2454   REGISTER_KERNEL_BUILDER(                                                     \
2455       Name("_MklConv3D")                                                       \
2456           .Device(DEVICE_CPU)                                                  \
2457           .TypeConstraint<T>("T")                                              \
2458           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2459       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \
2460   REGISTER_KERNEL_BUILDER(                                                     \
2461       Name("_MklNativeConv3D")                                                 \
2462           .Device(DEVICE_CPU)                                                  \
2463           .TypeConstraint<T>("T")                                              \
2464           .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
2465       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>);
2466 TF_CALL_float(REGISTER_MKL_CPU_3D);
2467 TF_CALL_bfloat16(REGISTER_MKL_CPU_3D);
2468 
2469 }  // namespace tensorflow
2470 #endif  // INTEL_MKL
2471