/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL

#include "tensorflow/core/kernels/mkl/mkl_conv_ops.h"

#include <algorithm>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "mkldnn.hpp"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

using mkldnn::convolution_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
using ReorderPd = mkldnn::reorder::primitive_desc;

namespace tensorflow {

// This structure aggregates multiple inputs to Conv2DFwd* methods.
struct MklConvFwdParams {
  memory::dims src_dims;
  memory::dims filter_dims;
  memory::dims bias_dims;
  memory::dims dst_dims;
  memory::dims strides;
  memory::dims dilations;
  memory::dims padding_left;
  memory::dims padding_right;
  MklTensorFormat tf_fmt;
  bool native_format;
  string dtypes = string("");
  struct PostOpParam {
    string name;
    mkldnn::algorithm alg;
    std::vector<float> param;
    std::string partial_key;
  };
  std::vector<PostOpParam> post_op_params;

  MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims,
                   memory::dims bias_dims, memory::dims dst_dims,
                   memory::dims strides, memory::dims dilations,
                   memory::dims padding_left, memory::dims padding_right,
                   MklTensorFormat tf_fmt, bool native_format)
      : src_dims(src_dims),
        filter_dims(filter_dims),
        bias_dims(bias_dims),
        dst_dims(dst_dims),
        strides(strides),
        dilations(dilations),
        padding_left(padding_left),
        padding_right(padding_right),
        tf_fmt(tf_fmt),
        native_format(native_format) {}
};

// With quantization, input, filter, and output can have different types,
// so we use a different template parameter for each type.
template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
class MklConvFwdPrimitive : public MklPrimitive {
 public:
  explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create convolution primitive
    if (context_.conv_fwd == nullptr) {
      Setup(convFwdDims);
    }
  }

  ~MklConvFwdPrimitive() {}

  // Convolution forward execute with bias
  //   src_data: input data buffer of src
  //   filter_data: input data buffer of filter (weights)
  //   bias_data: input data buffer of bias
  //   dst_data: output data buffer of dst
  void Execute(const Tinput* src_data, const Tfilter* filter_data,
               const Tbias* bias_data, const Toutput* dst_data,
               std::shared_ptr<stream> fwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
    // TODO: Create a common function and avoid the duplicate code
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
    context_.filter_mem->set_data_handle(
        static_cast<void*>(const_cast<Tfilter*>(filter_data)), *fwd_stream);
    if (bias_data != nullptr) {
      context_.bias_mem->set_data_handle(
          static_cast<void*>(const_cast<Tbias*>(bias_data)), *fwd_stream);
    }
    context_.dst_mem->set_data_handle(
        static_cast<void*>(const_cast<Toutput*>(dst_data)), *fwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)));
    context_.filter_mem->set_data_handle(
        static_cast<void*>(const_cast<Tfilter*>(filter_data)));
    if (bias_data != nullptr) {
      context_.bias_mem->set_data_handle(
          static_cast<void*>(const_cast<Tbias*>(bias_data)));
    }
    context_.dst_mem->set_data_handle(
        static_cast<void*>(const_cast<Toutput*>(dst_data)));
#endif  // ENABLE_MKLDNN_THREADPOOL

    DCHECK_EQ(context_.fwd_primitives.size(),
              context_.fwd_primitives_args.size());
    for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
      context_.fwd_primitives.at(i).execute(*fwd_stream,
                                            context_.fwd_primitives_args.at(i));
    }

    // After execution, set data handle back
    context_.src_mem->set_data_handle(DummyData);
    context_.filter_mem->set_data_handle(DummyData);
    if (bias_data != nullptr) {
      context_.bias_mem->set_data_handle(DummyData);
    }
    context_.dst_mem->set_data_handle(DummyData);
  }

  // Convolution forward execute without bias
  //   src_data: input data buffer of src
  //   filter_data: input data buffer of filter (weights)
  //   dst_data: output data buffer of dst
  void Execute(const Tinput* src_data, const Tfilter* filter_data,
               const Toutput* dst_data, std::shared_ptr<stream> fwd_stream) {
    Execute(src_data, filter_data, nullptr, dst_data, fwd_stream);
  }

  std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for Conv2D Fwd op
  struct ConvFwdContext {
    // MKL-DNN memory
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> filter_mem;
    std::shared_ptr<mkldnn::memory> bias_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;

    // Desc & primitive desc
    std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;

    // Memory desc
    std::shared_ptr<mkldnn::memory::desc> src_md;
    std::shared_ptr<mkldnn::memory::desc> filter_md;
    std::shared_ptr<mkldnn::memory::desc> bias_md;
    std::shared_ptr<mkldnn::memory::desc> dst_md;

    // Convolution primitive
    std::shared_ptr<ConvFwdPd> fwd_pd;
    std::shared_ptr<mkldnn::primitive> conv_fwd;

    std::vector<mkldnn::primitive> fwd_primitives;
    std::vector<std::unordered_map<int, memory>> fwd_primitives_args;

    ConvFwdContext()
        : src_mem(nullptr),
          filter_mem(nullptr),
          bias_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          src_md(nullptr),
          filter_md(nullptr),
          bias_md(nullptr),
          fwd_pd(nullptr),
          conv_fwd(nullptr) {}
  };

  void Setup(const MklConvFwdParams& convFwdDims) {
    memory::format_tag user_data_fmt;
    if (convFwdDims.native_format) {
      user_data_fmt = MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt);
    } else {
      // Create memory descriptors for convolution data w/ no specified format
      user_data_fmt = memory::format_tag::any;
    }
    context_.src_md.reset(new memory::desc(
        {convFwdDims.src_dims}, MklDnnType<Tinput>(), user_data_fmt));

    context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims},
                                              MklDnnType<Tfilter>(),
                                              memory::format_tag::any));

    context_.dst_md.reset(new memory::desc(
        {convFwdDims.dst_dims}, MklDnnType<Toutput>(), user_data_fmt));

    if (!convFwdDims.bias_dims.empty())
      context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims},
                                              MklDnnType<Tbias>(),
                                              memory::format_tag::any));

    // Create a convolution descriptor
    if (!convFwdDims.bias_dims.empty()) {
      context_.fwd_desc.reset(new convolution_forward::desc(
          prop_kind::forward, mkldnn::algorithm::convolution_direct,
          *context_.src_md, *context_.filter_md, *context_.bias_md,
          *context_.dst_md, convFwdDims.strides, convFwdDims.dilations,
          convFwdDims.padding_left, convFwdDims.padding_right));
    } else {
      context_.fwd_desc.reset(new convolution_forward::desc(
          prop_kind::forward, mkldnn::algorithm::convolution_direct,
          *context_.src_md, *context_.filter_md, *context_.dst_md,
          convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
          convFwdDims.padding_right));
    }

    context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));

    // Check if there are any fusions as post-ops
    auto const& post_op_params = convFwdDims.post_op_params;
    mkldnn::primitive_attr post_ops_attr;
    mkldnn::post_ops post_ops;
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "activation") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha,
                                  op_beta);
        } else if (post_op_param.name == "sum") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          float op_scale = post_op_param.param[0];
          post_ops.append_sum(op_scale);
        } else if (post_op_param.name == "output_scale") {
          if (post_op_param.param.size() == 1) {
            post_ops_attr.set_output_scales(0, post_op_param.param);
          } else {
            post_ops_attr.set_output_scales(2, post_op_param.param);
          }
        } else {
          DCHECK((post_op_param.name == "activation") ||
                 (post_op_param.name == "sum") ||
                 (post_op_param.name == "output_scale"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
      context_.fwd_pd.reset(
          new ConvFwdPd(*context_.fwd_desc, post_ops_attr, cpu_engine_));
    } else {
      context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
    }

    // Create memory primitive based on dummy data
    context_.src_mem.reset(
        new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
    context_.filter_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
                                         cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));

    // Create convolution primitive and add it to net
    if (!convFwdDims.bias_dims.empty()) {
      context_.bias_mem.reset(new memory(
          {{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format_tag::x},
          cpu_engine_, DummyData));
      context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
      context_.fwd_primitives_args.push_back(
          {{MKLDNN_ARG_SRC, *context_.src_mem},
           {MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
           {MKLDNN_ARG_BIAS, *context_.bias_mem},
           {MKLDNN_ARG_DST, *context_.dst_mem}});
    } else {
      context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
      context_.fwd_primitives_args.push_back(
          {{MKLDNN_ARG_SRC, *context_.src_mem},
           {MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
           {MKLDNN_ARG_DST, *context_.dst_mem}});
    }
    context_.fwd_primitives.push_back(*context_.conv_fwd);
  }

  struct ConvFwdContext context_;
};

// TODO(nhasabni): We should not require passing a type to MklPrimitiveFactory.
// But removing the need for a type in MklPrimitiveFactory is going to require
// a change to every MKL op. So not doing it now. Instead passing float.
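// Illustrative usage sketch of the primitive/factory pair above (exposition
// only; the real call site is in MklConvOp::Compute below, and the type
// arguments and buffer names here are hypothetical):
//
//   MklConvFwdParams params(src_dims, filter_dims, bias_dims, dst_dims,
//                           strides, dilations, padding_left, padding_right,
//                           tf_fmt, native_format);
//   auto* conv_fwd =
//       MklConvFwdPrimitiveFactory<float, float, float, float>::Get(
//           params, /*do_not_cache=*/false);
//   conv_fwd->Execute(src_data, filter_data, bias_data, dst_data, fwd_stream);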
template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<float> {
 public:
  static MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>* Get(
      const MklConvFwdParams& convFwdDims, bool do_not_cache) {
    MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>* conv_fwd = nullptr;

    if (do_not_cache) {
      // Always create a new primitive
      conv_fwd =
          new MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>(convFwdDims);
    } else {
      // Try to find a suitable one in the pool
      conv_fwd =
          dynamic_cast<MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>*>(
              MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias,
                                         Toutput>::GetInstance()
                  .GetConvFwd(convFwdDims));
      if (conv_fwd == nullptr) {
        conv_fwd = new MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>(
            convFwdDims);
        MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias,
                                   Toutput>::GetInstance()
            .SetConvFwd(convFwdDims, conv_fwd);
      }
    }

    return conv_fwd;
  }

 private:
  MklConvFwdPrimitiveFactory() {}
  ~MklConvFwdPrimitiveFactory() {}

  static const int kDilationH = 0, kDilationW = 1;

  static MklConvFwdPrimitiveFactory& GetInstance() {
    static MklConvFwdPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklConvFwdParams& convFwdDims) {
    string prefix = "conv_fwd_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(convFwdDims.src_dims);
    key_creator.AddAsKey(convFwdDims.filter_dims);
    key_creator.AddAsKey(convFwdDims.bias_dims);
    key_creator.AddAsKey(convFwdDims.dst_dims);
    key_creator.AddAsKey(convFwdDims.strides);
    key_creator.AddAsKey(convFwdDims.dilations);
    key_creator.AddAsKey(convFwdDims.padding_left);
    key_creator.AddAsKey(convFwdDims.padding_right);
    key_creator.AddAsKey(convFwdDims.dtypes);
    if (convFwdDims.native_format) {
      key_creator.AddAsKey(convFwdDims.tf_fmt);
    }

    // Generate keys for post-ops
    for (auto const& post_op_param : convFwdDims.post_op_params) {
      key_creator.AddAsKey(post_op_param.name);
      if (post_op_param.name == "activation") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        for (auto& param : post_op_param.param) {
          key_creator.AddAsKey(param);
        }
      } else if (post_op_param.name == "sum") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        for (auto& param : post_op_param.param) {
          key_creator.AddAsKey(param);
        }
      } else if (post_op_param.name == "output_scale") {
        key_creator.AddAsKey(post_op_param.partial_key);
      } else {
        return string("not_a_key");
      }
    }

    return key_creator.GetKey();
  }

  MklPrimitive* GetConvFwd(const MklConvFwdParams& convFwdDims) {
    string key = CreateKey(convFwdDims);
    return this->GetOp(key);
  }

  void SetConvFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
    string key = CreateKey(convFwdDims);
    this->SetOp(key, op);
  }
};

// Base class for convolution forward operations
template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
          typename Toutput, typename Ttemp_output, typename Tpadding,
          bool bias_enabled, bool pad_enabled, bool is_depthwise,
          bool native_format>
class MklConvOp : public OpKernel {
 public:
  ~MklConvOp() {}

  explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));

    // Conv and QuantizedConv ops have different padding attributes
    // (`padding_list` versus `explicit_paddings`), but at most one of the
    // two attributes is expected to be set.
    OP_REQUIRES(
        context,
        !(context->HasAttr("padding_list") &&
          context->HasAttr("explicit_paddings")),
        errors::InvalidArgument("Can only have 1 `padding` list at most"));
    if (context->HasAttr("padding_list")) {
      OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_));
    }
    if (context->HasAttr("explicit_paddings")) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("explicit_paddings", &padding_list_));
    }

    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 or 5 dimensions"));

    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    is_filter_const_ = false;
    if (context->HasAttr("is_filter_const")) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("is_filter_const", &is_filter_const_));
    }

    if (strides_.size() == 4) {
      OP_REQUIRES(context, dilations_.size() == 4,
                  errors::InvalidArgument("Sliding window dilations field must "
                                          "specify 4 dimensions"));
      const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
      const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
      const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
      const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
      OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
                  errors::InvalidArgument(
                      "Current implementation does not yet support "
                      "dilations in the batch and depth dimensions."));
      OP_REQUIRES(
          context, dilation_h > 0 && dilation_w > 0,
          errors::InvalidArgument("Dilated rates should be larger than 0."));
    } else if (strides_.size() == 5) {
      OP_REQUIRES(context, dilations_.size() == 5,
                  errors::InvalidArgument("Dilation rates field must "
                                          "specify 5 dimensions"));
      OP_REQUIRES(context,
                  (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
                   GetTensorDim(dilations_, data_format_, 'C') == 1),
                  errors::InvalidArgument(
                      "Current implementation does not yet support "
                      "dilations rates in the batch and depth dimensions."));
      OP_REQUIRES(
          context,
          (GetTensorDim(dilations_, data_format_, '0') > 0 &&
           GetTensorDim(dilations_, data_format_, '1') > 0 &&
           GetTensorDim(dilations_, data_format_, '2') > 0),
          errors::InvalidArgument("Dilated rates should be larger than 0."));
    }
  }

  void Compute(OpKernelContext* context) override {
    try {
      // Input tensors
      const Tensor& src_tensor = MklGetInput(context, kInputIndex_Src);
      const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);
      MklDnnShape src_mkl_shape, filter_mkl_shape;
      GetMklShape(context, kInputIndex_Src, &src_mkl_shape, native_format);
      GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape,
                  native_format);

      OP_REQUIRES(context, !filter_mkl_shape.IsMklTensor(),
                  errors::InvalidArgument("Filter should not be in "
                                          "Mkl Layout"));

      MklDnnData<Tinput> src(&cpu_engine_);
      MklDnnData<Tfilter> filter(&cpu_engine_);

      memory::dims src_dims, filter_dims, padding_left, padding_right,
          dilations, strides;
      memory::dims dst_dims_tf_order, dst_dims_mkl_order;

      // For any Conv with `EXPLICIT` padding, get padding from the
      // `padding_list` attribute. Otherwise, get it from one of the inputs.
      bool pad_attr_enabled = false;
      for (auto const& padding_val : padding_list_) {
        if (padding_val) {
          pad_attr_enabled = true;
          break;
        }
      }

      if (fuse_pad_ || pad_attr_enabled) {
        PadWithConvFusion(context, padding_left, padding_right,
                          pad_attr_enabled);
      }

      // Get shapes of input tensors in MKL-DNN order
      MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
                              dilations_);
      auto src_tf_shape = GetTfShape(context, kInputIndex_Src, native_format);
      auto filter_tf_shape =
          GetTfShape(context, kInputIndex_Filter, native_format);
      conv_utl.GetConvFwdSizesInMklOrder(
          src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
          &dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left,
          &padding_right, (fuse_pad_ || pad_attr_enabled), is_depthwise);

      if (!context->status().ok()) return;

      // Check for corner case - if there is nothing to compute, return.
      TensorShape dst_tf_shape = MklDnnDimsToTFShape(dst_dims_tf_order);

      // Corner cases: output with 0 elements and 0 batch size.
      Tensor* dst_tensor = nullptr;
      bool emit_filter_output = (typeid(Tinput) == typeid(Tfilter) &&
                                 typeid(Tinput) == typeid(Toutput) &&
                                 (typeid(Tinput) == typeid(float) ||
                                  typeid(Tinput) == typeid(bfloat16))) &&
                                !native_format;
      if (dst_tf_shape.num_elements() == 0 || dst_dims_tf_order[0] == 0) {
        MklDnnShape dst_mkl_shape;
        dst_mkl_shape.SetMklTensor(false);
        AllocateOutputSetMklShape(context, kOutputIndex_Dst, &dst_tensor,
                                  src_tf_shape, dst_mkl_shape, native_format);

        // MklConv2D/3D also outputs converted filter as 2nd output.
        filter_mkl_shape.SetMklTensor(false);
        Tensor* output_filter_tensor = nullptr;
        if (emit_filter_output) {
          filter_mkl_shape.SetMklTensor(false);
          AllocateOutputSetMklShape(context, kOutputIndex_Filter,
                                    &output_filter_tensor, filter_tf_shape,
                                    filter_mkl_shape);
        }
        return;
      }

      bool is_conv2d = (strides_.size() == 4);

      if (!is_conv2d) {
        OP_REQUIRES(
            context, !pad_enabled,
            errors::InvalidArgument("Pad + Conv fusion only works for 2D"));
        OP_REQUIRES(
            context, !fuse_pad_,
            errors::InvalidArgument("Pad+Conv fusion only works for 2D"));
      }

      // TODO(gzmkl) 3-D support for Depthwise is not there
      if (is_depthwise) {
        OP_REQUIRES(context, is_conv2d,
                    errors::InvalidArgument(
                        "Only 2D convolution is supported for depthwise."));
      }

      // Create memory for user data.
      // Describe what the inputs and outputs of Convolution look like. Also
      // specify buffers containing actual input and output data.
      auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
                              : TFDataFormatToMklDnn3DDataFormat(data_format_);

      auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
      // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
      OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
                  errors::InvalidArgument("Invalid data format"));

      // If input is in MKL layout, then simply grab the layout; otherwise,
      // construct TF layout for input.
      // For constructing TF layout for input, although input shape (src_dims)
      // is required to be in MKL-DNN order, the input layout is actually in
      // TF layout depending on the data format:
      //     Conv2D: NHWC or NCHW
      //     Conv3D: NDHWC or NCDHW
      auto src_md =
          src_mkl_shape.IsMklTensor()
              ? src_mkl_shape.GetMklLayout()
              : memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag);
      src.SetUsrMem(src_md, &src_tensor);

      // Although filter shape (filter_dims) required is in MKL-DNN order,
      // the layout is TensorFlow's layout (HWIO), or (HWIGO) for
      // depthwise/group convolutions.
      auto filter_format = is_conv2d ? (is_depthwise ? memory::format_tag::hwigo
                                                     : memory::format_tag::hwio)
                                     : memory::format_tag::dhwio;

      DCHECK(!filter_mkl_shape.IsMklTensor());
      auto filter_md =
          filter_mkl_shape.IsMklTensor()
              ? filter_mkl_shape.GetMklLayout()
              : memory::desc(filter_dims, MklDnnType<Tfilter>(), filter_format);
      filter.SetUsrMem(filter_md, &filter_tensor);

      // MKL-DNN dilations start from 0.
      for (int i = 0; i < dilations.size(); ++i) --dilations[i];

      // In some cases, primitive descriptor could potentially contain
      // large buffers. As a result, we don't cache these primitives if the
      // environment variable `TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE` is set to True.
      // MKL-DNN allocates buffers in the following cases:
      //   1. Legacy CPU without AVX512/AVX2, or
      //   2. 1x1 convolution with strides != 1
      bool do_not_cache =
          MklPrimitiveFactory<Tinput>::IsPrimitiveMemOptEnabled() &&
          (src_dims[MklDnnDims::Dim_N] > kSmallBatchSize) &&
          (MklPrimitiveFactory<Tinput>::IsLegacyPlatform() ||
           IsConv1x1StrideNot1(filter_dims, strides));

      // Get a conv2d fwd from primitive pool
      MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Ttemp_output>* conv_fwd =
          nullptr;
      memory::dims bias_dims = {};
      if (fuse_biasadd_) {
        conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
      }
      MklConvFwdParams convFwdDims(
          src_dims, filter_dims, fuse_biasadd_ ? bias_dims : NONE_DIMS,
          dst_dims_mkl_order, strides, dilations, padding_left, padding_right,
          tf_fmt, native_format);

      // TODO(mdfaijul): Extend the basic parameters for data types and fusions
      this->ExtendConvFwdParams(context, convFwdDims);
      conv_fwd =
          MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, Ttemp_output>::Get(
              convFwdDims, do_not_cache);
      // Allocate output tensors `dst_tensor` and `filter_out_tensor`
      MklDnnShape output_mkl_shape;
      std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
      AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt,
                           &output_mkl_shape, &dst_tensor);

      Tensor* filter_out_tensor = nullptr;
      if (emit_filter_output) {
        AllocateFilterOutputTensor(context, *conv_fwd_pd,
                                   TFShapeToMklDnnDims(filter_tf_shape),
                                   &filter_out_tensor);
      }

      Ttemp_output* dst_data =
          reinterpret_cast<Ttemp_output*>(dst_tensor->flat<Toutput>().data());

      // Check whether src and filter need to be reordered.
      Tinput* src_data = nullptr;
      if (src_md != conv_fwd_pd->src_desc()) {
        src.SetUsrMem(src_md, &src_tensor);
        src.CheckReorderToOpMem(conv_fwd_pd->src_desc(), cpu_engine_, context);
        src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
      } else {
        src_data = static_cast<Tinput*>(
            const_cast<Tinput*>(src_tensor.flat<Tinput>().data()));
      }

      Tfilter* filter_data = nullptr;
      if (filter_md != conv_fwd_pd->weights_desc()) {
        bool is_filter_cached = false;
        // If the filter is a constant, we can avoid the conversion of the
        // filter from TensorFlow format to MKL format by caching the filter
        // when it is converted for the first time. This cached filter can then
        // be reused in subsequent iterations.
        if (is_filter_const_) {
          if (IsFilterCacheEmpty(context)) {
            // Cache filter if it is not already cached.
            CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
                        filter, filter_md, filter_mkl_shape);
          }
          filter_data = GetCachedFilter(context, conv_fwd_pd->weights_desc());
          is_filter_cached = (filter_data != nullptr);
        }
        if (!is_filter_cached) {
          filter.SetUsrMem(filter_md, &filter_tensor);
          if (filter_out_tensor == nullptr) {
            filter.CheckReorderToOpMem(conv_fwd_pd->weights_desc(), cpu_engine_,
                                       context);
          } else {
            filter.CheckReorderToOpMem(
                conv_fwd_pd->weights_desc(),
                filter.GetTensorBuffer(filter_out_tensor), cpu_engine_,
                context);
          }
          filter_data =
              static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
        }
      } else {
        filter_data = static_cast<Tfilter*>(
            const_cast<Tfilter*>(filter_tensor.flat<Tfilter>().data()));
      }

      // Execute convolution
      std::shared_ptr<stream> fwd_cpu_stream;
      fwd_cpu_stream.reset(CreateStream(context, conv_fwd->GetEngine()));
      if (fuse_biasadd_) {
        const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
        Tbias* bias_data =
            this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
        conv_fwd->Execute(src_data, filter_data, bias_data, dst_data,
                          fwd_cpu_stream);
      } else {
        conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream);
      }

      // Delete primitive since it is not cached.
      if (do_not_cache) delete conv_fwd;

    } catch (mkldnn::error& e) {
      string error_msg = tensorflow::strings::StrCat(
          "Status: ", e.status, ", message: ", string(e.message), ", in file ",
          __FILE__, ":", __LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

  void PadWithConvFusion(OpKernelContext* context, memory::dims& padding_left,
                         memory::dims& padding_right, bool pad_attr_enabled) {
    Tpadding* paddings = nullptr;
    if (pad_attr_enabled) {
      paddings = padding_list_.data();
    } else {
      const Tensor& paddings_tf = MklGetInput(context, input_index_pad_);
      OP_REQUIRES(context, paddings_tf.dims() == 2,
                  errors::InvalidArgument("paddings must be 2-dimensional: ",
                                          paddings_tf.shape().DebugString()));
      // Flatten tensor to get individual paddings.
      paddings = static_cast<Tpadding*>(
          const_cast<Tpadding*>(paddings_tf.flat<Tpadding>().data()));
    }
    // If the data format is NHWC, indices 0, 1, 6 and 7 of paddings(_tf)
    // will be zero.
    // Example:
    //   paddings_tf = [ [0, 0] [1, 2] [3, 4] [0, 0] ],
    //   flat method = row-major, then:
    //   paddings = {0, 0, 1, 2, 3, 4, 0, 0}.
    // Hence, the values are: top = 1, bottom = 2, left = 3, right = 4.
    //
    // Similarly, if the data format is NCHW, indices 0, 1, 2 and 3 of
    // paddings(_tf) will be zero.
    // i.e. for the above example, paddings = {0, 0, 0, 0, 1, 2, 3, 4}.
    int64 pad_top = 0, pad_left = 0;
    int64 pad_bottom = 0, pad_right = 0;
    string data_format = ToString(data_format_);
    if (data_format == "NHWC") {
      pad_top = paddings[2];
      pad_bottom = paddings[3];
      pad_left = paddings[4];
      pad_right = paddings[5];
    } else if (data_format == "NCHW") {
      pad_top = paddings[4];
      pad_bottom = paddings[5];
      pad_left = paddings[6];
      pad_right = paddings[7];
    }
    // Create padding arrays for MKL-DNN convolutions.
    // MKL-DNN uses asymmetric padding.
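    // For the NHWC example above, this yields padding_left = {1, 3}
    // (top, left) and padding_right = {2, 4} (bottom, right).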
    padding_left = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
    padding_right = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
  }

 protected:
  void set_fuse_biasadd(bool fuse_biasadd) { fuse_biasadd_ = fuse_biasadd; }

  void set_fuse_activation(bool fuse_activation,
                           mkldnn::algorithm activation_alg,
                           float alpha_or_upbound = 0.0) {
    fuse_activation_ = fuse_activation;
    activation_alg_ = activation_alg;
    // This variable is used for alpha in leakyrelu or upper bound in relu6
    // depending on the context
    alpha_or_upbound_ = alpha_or_upbound;
  }

  void set_fuse_pad(bool fuse_pad) {
    fuse_pad_ = fuse_pad;
    // In PadwithFusedConv OP, pad is the fourth index.
    input_index_pad_ = 3;
  }

  void set_fuse_add(bool fuse_add) { fuse_add_ = fuse_add; }

  // This method is for the base class MklConvOp, which handles the
  // floating point implementation of Conv. The quantized conv implementations
  // will use overridden versions of this method.
  virtual void ExtendConvFwdParams(OpKernelContext* context,
                                   MklConvFwdParams& params) {
    // Create a string from data types of input, filter, bias, and output.
    params.dtypes.append(typeid(Tinput).name());
    params.dtypes.append(typeid(Tfilter).name());
    params.dtypes.append(typeid(Tbias).name());
    params.dtypes.append(typeid(Toutput).name());

    // Add fusions as post ops
    // NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
    // checking `fuse_biasadd_` flag.
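    // For example (illustrative only): a Conv2D fused with Add and Relu ends
    // up with post_op_params holding {"sum", undef, {1.0}, ""} followed by
    // {"activation", eltwise_relu, {1.0, 0.0, 0.0}, ""}, which Setup() above
    // translates into oneDNN sum/eltwise post-ops.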
    if (fuse_add_) {
      params.post_op_params.push_back(
          {"sum", mkldnn::algorithm::undef, {1.0}, ""});
    }
    if (fuse_activation_) {
      params.post_op_params.push_back(
          {"activation", activation_alg_, {1.0, alpha_or_upbound_, 0.0}, ""});
    }
  }

  virtual Tbias* GetBiasHandle(OpKernelContext* context,
                               std::shared_ptr<ConvFwdPd>& conv2d_fwd_pd,
                               const Tensor& bias_tensor) {
    if (fuse_biasadd_) {
      return static_cast<Tbias*>(
          const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
    }
    return nullptr;
  }

  virtual void AllocateOutputTensor(OpKernelContext* context,
                                    const ConvFwdPd& conv_prim_desc,
                                    const memory::dims& output_dims_mkl_order,
                                    MklTensorFormat output_tf_format,
                                    MklDnnShape* output_mkl_shape,
                                    Tensor** output_tensor) {
    DCHECK(output_tensor);
    auto dst_md = conv_prim_desc.dst_desc();

    if (!std::is_same<Ttemp_output, Toutput>::value) {
      dst_md.data.data_type =
          static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
    }

    // Allocate shape of MKL tensor
    output_mkl_shape->SetMklTensor(true);
    output_mkl_shape->SetMklLayout(&dst_md);
    output_mkl_shape->SetElemType(MklDnnType<Toutput>());
    output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(),
                                  output_dims_mkl_order, output_tf_format);

    // Allocate shape of TF tensor
    TensorShape output_tf_shape;
    output_tf_shape.AddDim((dst_md.get_size() / sizeof(Toutput)));
    if (native_format) {
      output_tf_shape = output_mkl_shape->GetTfShape();
    }

    if (fuse_add_) {
      const Tensor& add_tensor = MklGetInput(context, kInputIndex_Add);
      MklDnnShape add_mkl_shape;
      GetMklShape(context, kInputIndex_Add, &add_mkl_shape, native_format);
      // Forward the summand tensor to the output only if it has no other
      // references, otherwise make a copy of it.
      if (native_format && context->forward_input_to_output_with_shape(
                               kInputIndex_Add, kOutputIndex_Dst,
                               output_tf_shape, output_tensor)) {
        return;
      }
      // Check if reorder is needed
      if (!native_format && add_mkl_shape == *output_mkl_shape &&
          ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
                                              kOutputIndex_Dst, output_tensor,
                                              add_mkl_shape, false)) {
        return;
      } else {
        AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
                                  output_tf_shape, *output_mkl_shape,
                                  native_format);
        auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
            output_mkl_shape->GetTfDataFormat());
        OP_REQUIRES(context, output_format_tag != memory::format_tag::undef,
                    errors::InvalidArgument(
                        "MklConvOp: AddN fusion: Invalid data format"));
        auto add_md =
            add_mkl_shape.IsMklTensor()
                ? add_mkl_shape.GetMklLayout()
                : memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
                               output_format_tag);
        void* add_buf = static_cast<void*>(
            const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
        void* dst_buf =
            static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
        if (native_format) {
          // We are simply deep copying the add_tensor to output_tensor without
          // changing memory layout, hence using same memory descriptor.
          add_md = dst_md =
              memory::desc({add_tensor.NumElements()}, MklDnnType<Toutput>(),
                           mkldnn::memory::format_tag::x);
        }
        fuse_add_src_.reset(new memory(add_md, this->cpu_engine_, add_buf));
        fuse_add_dst_.reset(new memory(dst_md, this->cpu_engine_, dst_buf));
        auto reorder_desc =
            ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md);

        CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
                                this->cpu_engine_, context);
      }
    } else {
      AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
                                output_tf_shape, *output_mkl_shape,
                                native_format);
    }
  }

  engine cpu_engine_ = engine(engine::kind::cpu, 0);

 private:
  std::shared_ptr<mkldnn::memory> fuse_add_src_;
  std::shared_ptr<mkldnn::memory> fuse_add_dst_;
  std::vector<int32> strides_;
  std::vector<int32> dilations_;
  std::vector<Tpadding> padding_list_;
  bool is_filter_const_;
  mutex mu_;
  Padding padding_;
  TensorFormat data_format_;
  PersistentTensor cached_filter_data_ptensor_ TF_GUARDED_BY(mu_);
  PersistentTensor cached_filter_md_ptensor_ TF_GUARDED_BY(mu_);

  // Initialize to values the template is instantiated with
  bool fuse_biasadd_ = bias_enabled;
  bool fuse_activation_ = false;
  bool fuse_pad_ = pad_enabled;
  bool fuse_add_ = false;

  // This variable is used for alpha in leakyrelu or upper bound in relu6
  // depending on the context
  float alpha_or_upbound_ = 0.0;
  mkldnn::algorithm activation_alg_ = mkldnn::algorithm::undef;

  int input_index_pad_ = 2;

  const int kInputIndex_Src = 0, kInputIndex_Filter = 1, kInputIndex_Bias = 2;
  const int kInputIndex_Add = 3;
  const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
  const int kDilationH = 0, kDilationW = 1;

  MklTensorFormat GetFilterTfDataFormat(const MklDnnShape* filter_mkl_shape,
                                        const ConvFwdPd& conv_prim_desc) const {
    DCHECK(filter_mkl_shape);
    return filter_mkl_shape->GetTfDataFormat();
  }

  // Allocate persistent tensors for cached filter data and
  // cached filter memory descriptor (data format)
  void AllocatePersistentTensor(OpKernelContext* context,
                                const ConvFwdPd& conv_prim_desc,
                                Tensor** filter_tensor,
                                const MklDnnShape* filter_mkl_shape) {
    DCHECK(filter_tensor);
    TensorShape filter_tf_shape;
    filter_tf_shape.AddDim(
        (conv_prim_desc.weights_desc().get_size() / sizeof(Tfilter)));
    OP_REQUIRES_OK(context, context->allocate_persistent(
                                DataTypeToEnum<Tfilter>::value, filter_tf_shape,
                                &cached_filter_data_ptensor_, filter_tensor));

    Tensor* second_tensor = nullptr;

    // There is no tensor format in DNNL 1.x. So we cache the complete filter
    // descriptor as a flat byte array.
    TensorShape cached_filter_md_shape;
    memory::desc weights_desc = conv_prim_desc.weights_desc();
    // We don't use the .get_size() method of memory::desc since it returns the
    // size required to store the primitive's input memory, which is much more
    // than the size of the memory::desc itself.
    cached_filter_md_shape.AddDim(sizeof(weights_desc) / sizeof(uint8));
    OP_REQUIRES_OK(context, context->allocate_persistent(
                                DT_UINT8, cached_filter_md_shape,
                                &cached_filter_md_ptensor_, &second_tensor));
    *reinterpret_cast<memory::desc*>(second_tensor->flat<uint8>().data()) =
        weights_desc;
  }

  void AllocatePersistentTensor(OpKernelContext* context,
                                const ConvFwdPd& conv_prim_desc,
                                Tensor** filter_tensor) {
    AllocatePersistentTensor(context, conv_prim_desc, filter_tensor, nullptr);
  }

  void AllocateFilterOutputTensor(OpKernelContext* context,
                                  const ConvFwdPd& conv_prim_desc,
                                  const memory::dims& filter_dims_tf_order,
                                  Tensor** filter_tensor) {
    DCHECK(filter_tensor);
    auto filter_md = conv_prim_desc.weights_desc();

    // Allocate shape of MKL tensor
    MklDnnShape filter_mkl_shape;
    filter_mkl_shape.SetMklTensor(true);
    filter_mkl_shape.SetMklLayout(&filter_md);
    filter_mkl_shape.SetElemType(MklDnnType<Tfilter>());

    // The format of the filter is actually OIhw8i8o, but TF doesn't support
    // this format. Just use format::blocked for now because the layout
    // is stored in the MKL data.
    filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
                                 filter_dims_tf_order,
                                 MklTensorFormat::FORMAT_BLOCKED);

    // Allocate the data space for the filter to propagate as a TF tensor.
    TensorShape filter_tf_shape;
    filter_tf_shape.AddDim((filter_md.get_size() / sizeof(Tfilter)));

    AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor,
                              filter_tf_shape, filter_mkl_shape);
  }

  // TODO(intel-mkl): This function does not seem to be called. Remove it.
  // Prepare and execute net - checks for input and output reorders.
  void PrepareAndExecuteNet(const ConvFwdPd& conv_prim_desc,
                            MklDnnData<Tinput>* src,
                            MklDnnData<Tfilter>* filter,
                            MklDnnData<Tbias>* bias,
                            MklDnnData<Toutput>* output,
                            Tensor* filter_out_tensor) {
    DCHECK(filter_out_tensor);

    // Create reorders between user layout and MKL layout if it is needed and
    // add it to the net before convolution. No need to check for output
    // reorder as we propagate output layout to the next layer.
    src->CheckReorderToOpMem(conv_prim_desc.src_desc(), cpu_engine_);

    // Rather than re-ordering to a temp buffer, reorder directly to the
    // filter output tensor
    filter->CheckReorderToOpMem(conv_prim_desc.weights_desc(),
                                filter->GetTensorBuffer(filter_out_tensor));

    // Create convolution primitive and add it to net.
    std::vector<primitive> net;
    std::vector<std::unordered_map<int, memory>> net_args;
    if (bias) {
      DCHECK(fuse_biasadd_);
      net.push_back(convolution_forward(conv_prim_desc));
      net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
                          {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
                          {MKLDNN_ARG_BIAS, bias->GetOpMem()},
                          {MKLDNN_ARG_DST, output->GetOpMem()}});
    } else {
      DCHECK(!fuse_biasadd_);
      net.push_back(convolution_forward(conv_prim_desc));
      net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
                          {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
                          {MKLDNN_ARG_DST, output->GetOpMem()}});
    }
    ExecutePrimitive(net, &net_args, cpu_engine_);
  }

  // TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
  // be acquired before entering the function, since it is acquired
  // inside the function.
  inline bool IsFilterCacheEmpty(OpKernelContext* context)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    const Tensor& cached_filter_data_tensor =
        *cached_filter_data_ptensor_.AccessTensor(context);
    return (cached_filter_data_tensor.NumElements() == 0);
  }

  // Cache the converted filter in a persistent tensor.
  // Only one thread can execute this method at any given time.
  void CacheFilter(OpKernelContext* context,
                   const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
                   Tfilter* filter_data, const Tensor& filter_tensor,
                   MklDnnData<Tfilter>& filter, const memory::desc& filter_md,
                   const MklDnnShape& filter_mkl_shape) TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    const Tensor& cached_filter_data_tensor =
        *cached_filter_data_ptensor_.AccessTensor(context);

    // If filter is already cached, there's nothing to do.
    if (cached_filter_data_tensor.NumElements() > 0) {
      return;
    }

    // Otherwise, cache filter
    filter.SetUsrMem(filter_md, &filter_tensor);
    filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(),
                               this->cpu_engine_, context);
    filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());

    Tensor* filter_tensor_ptr = nullptr;
    AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr,
                             &filter_mkl_shape);
    void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
    size_t cached_filter_data_size = filter.GetOpMem().get_desc().get_size();
    memcpy(cached_filter_data, filter_data, cached_filter_data_size);
  }

  bool AreMemoryDescriptorsEqual(const memory::desc& filter_md,
                                 const Tensor& cached_filter_md) {
    auto filter_md_data = filter_md.data;
    const char* filter_data = reinterpret_cast<const char*>(&filter_md_data);

    auto cached_filter_md_data = cached_filter_md.scalar<int64>()();
    const char* cached_filter_data =
        reinterpret_cast<const char*>(&cached_filter_md_data);

    for (size_t i = 0; i < sizeof(filter_md_data); ++i) {
      if (*filter_data++ != *cached_filter_data++) {
        return false;
      }
    }
    return true;
  }

  Tfilter* GetCachedFilter(OpKernelContext* context,
                           const memory::desc& filter_md)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    const Tensor& cached_filter_data =
        *cached_filter_data_ptensor_.AccessTensor(context);
    const Tensor& cached_filter_md =
        *cached_filter_md_ptensor_.AccessTensor(context);

    // Check if the memory descriptor of the cached weights is the same as
    // filter_md. If so, we can use the cached weights; otherwise
    // return nullptr.
    if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
      return static_cast<Tfilter*>(
          const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
    }
    return nullptr;
  }
};

// Base class for fused convolution forward operations
template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
          typename Toutput, typename Ttemp_output, typename Tpadding,
          bool pad_enabled, bool native_format>
class MklFusedConvOp
    : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
                       Tpadding, false, false, false, native_format> {
 public:
  explicit MklFusedConvOp(OpKernelConstruction* context)
      : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
                  Tpadding, false, false, false, native_format>(context) {
    // Since we came here through the registration of _MklFusedConv2D, get
    // all information from 'fused_ops' and 'num_args'
    std::vector<string> fused_ops;
    OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));

    int num_args;
    OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
    OP_REQUIRES(context, !fused_ops.empty(),
                errors::InvalidArgument(
                    "Fused Conv2D must have at least one fused op."));

    if (fused_ops == std::vector<string>{"BiasAdd"}) {
      this->set_fuse_biasadd(true);
      OP_REQUIRES(context, num_args == 1,
                  errors::InvalidArgument(
                      "Fused Conv2D must have one extra argument: bias."));
    } else if (fused_ops == std::vector<string>{"Relu"}) {
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
    } else if (fused_ops == std::vector<string>{"Relu6"}) {
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
                                6.0);
    } else if (fused_ops == std::vector<string>{"Elu"}) {
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
    } else if (fused_ops == std::vector<string>{"LeakyRelu"}) {
      float leakyrelu_alpha;
      OP_REQUIRES_OK(context,
                     context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
                                leakyrelu_alpha);
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
      OP_REQUIRES(context, num_args == 1,
                  errors::InvalidArgument(
                      "Fused Conv2D must have one extra argument: bias."));
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
                                6.0);
      OP_REQUIRES(context, num_args == 1,
                  errors::InvalidArgument(
                      "Fused Conv2D must have one extra argument: bias."));
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
      OP_REQUIRES(context, num_args == 1,
                  errors::InvalidArgument(
                      "Fused Conv2D must have one extra argument: bias."));
    } else if (fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"}) {
      this->set_fuse_biasadd(true);
      float leakyrelu_alpha;
      OP_REQUIRES_OK(context,
                     context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
                                leakyrelu_alpha);
      OP_REQUIRES(context, num_args == 1,
                  errors::InvalidArgument(
                      "Fused Conv2D must have one extra argument: bias."));
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Add"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_add(true);
      OP_REQUIRES(
          context, num_args == 2,
          errors::InvalidArgument(
              "Fused Conv2D must have two extra arguments: bias and add."));
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_add(true);
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
      OP_REQUIRES(
          context, num_args == 2,
          errors::InvalidArgument(
              "Fused Conv2D must have two extra arguments: bias and add."));
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_add(true);
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
                                6.0);
      OP_REQUIRES(
          context, num_args == 2,
          errors::InvalidArgument(
              "Fused Conv2D must have two extra arguments: bias and add."));
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_add(true);
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
      OP_REQUIRES(
          context, num_args == 2,
          errors::InvalidArgument(
              "Fused Conv2D must have two extra arguments: bias and add."));
    } else if (fused_ops ==
               std::vector<string>{"BiasAdd", "Add", "LeakyRelu"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_add(true);
      float leakyrelu_alpha;
      OP_REQUIRES_OK(context,
                     context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
                                leakyrelu_alpha);
      OP_REQUIRES(
          context, num_args == 2,
          errors::InvalidArgument(
              "Fused Conv2D must have two extra arguments: bias and add."));
    } else {
      OP_REQUIRES(context, false,
                  errors::Unimplemented("Fusion is not implemented: [",
                                        absl::StrJoin(fused_ops, ","), "]"));
    }

    if (pad_enabled) {
      this->set_fuse_pad(true);
    }
  }

  virtual ~MklFusedConvOp() {}
};

template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
          typename Toutput, typename Ttemp_output, typename Tpadding,
          bool pad_enabled, bool bias_enabled, bool is_depthwise,
          bool native_format>
class MklFusedDepthwiseConvOp
    : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
                       Tpadding, bias_enabled, false, is_depthwise,
                       native_format> {
 public:
  explicit MklFusedDepthwiseConvOp(OpKernelConstruction* context)
      : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
                  Tpadding, bias_enabled, false, is_depthwise, native_format>(
            context) {
    // Since we came here through the registration of
    // _MklFusedDepthwiseConv2dNative, get all
    // information from 'fused_ops' and 'num_args'
    std::vector<string> fused_ops;
    OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));

    int num_args;
    OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
    OP_REQUIRES(context, !fused_ops.empty(),
                errors::InvalidArgument(
                    "Fused DepthwiseConv2D must have at least one fused op."));

    if (fused_ops == std::vector<string>{"BiasAdd"}) {
      this->set_fuse_biasadd(true);
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
                                6.0);
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
    } else {
      OP_REQUIRES(context, false,
                  errors::Unimplemented("Fusion is not implemented: [",
                                        absl::StrJoin(fused_ops, ","), "]"));
    }

    OP_REQUIRES(
        context, num_args == 1,
        errors::InvalidArgument(
            "Fused DepthwiseConv2D must have one extra argument: bias."));

    if (pad_enabled) {
      this->set_fuse_pad(true);
    }
  }

  virtual ~MklFusedDepthwiseConvOp() {}
};

// We create a new class for each version of Quantized Convolution and inherit
// from the FP32 version of the base class.
template <typename Device, typename Tinput, typename Tbias, typename Toutput,
          typename Ttemp_output, bool bias_enabled, bool is_depthwise>
class MklQuantizedConv2DOp
    : public MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output,
                       int32, bias_enabled, false, is_depthwise, false> {
 public:
  virtual ~MklQuantizedConv2DOp() {
    if (this->input_bias_ != nullptr) {
      delete this->input_bias_;
      input_bias_ = nullptr;
    }

    if (this->scaled_bias_ != nullptr) {
      delete this->scaled_bias_;
      scaled_bias_ = nullptr;
    }
  }

  explicit MklQuantizedConv2DOp(OpKernelConstruction* context)
      : MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
                  bias_enabled, false, is_depthwise, false>(context) {
    bool is_filter_const;
    OP_REQUIRES_OK(context,
                   context->GetAttr("is_filter_const", &is_filter_const));

    if (bias_enabled) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("is_bias_const", &is_bias_const_));
    }

    OP_REQUIRES(context, is_filter_const,
                errors::InvalidArgument("Filter must be a constant"));
  }

  void Compute(OpKernelContext* context) override {
    // Compute int32 output tensor
    MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
              bias_enabled, false, is_depthwise, false>::Compute(context);

    // Compute additional outputs: min/max scalars.
    int bias_index_offset;
    int bias_index_offset;
    bias_index_offset = bias_enabled ? 1 : 0;

    const float min_input =
        context->input(2 + bias_index_offset).flat<float>()(0);
    const float max_input =
        context->input(3 + bias_index_offset).flat<float>()(0);

    MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
    output_min_mkl_shape.SetMklTensor(false);
    output_max_mkl_shape.SetMklTensor(false);

    Tensor* output_min = nullptr;
    Tensor* output_max = nullptr;
    if (std::is_same<Toutput, quint8>::value ||
        std::is_same<Toutput, qint8>::value) {
      AllocateOutputSetMklShape(context, 1, &output_min, {},
                                output_min_mkl_shape);
      AllocateOutputSetMklShape(context, 2, &output_max, {},
                                output_max_mkl_shape);
      // This is the case where convolution and requantization are fused.
      output_min->flat<float>()(0) =
          context->input(6 + bias_index_offset).flat<float>()(0);
      output_max->flat<float>()(0) =
          context->input(7 + bias_index_offset).flat<float>()(0);
    } else {
      const Tensor& min_filter = context->input(4 + bias_index_offset);
      const Tensor& max_filter = context->input(5 + bias_index_offset);
      if (min_filter.dims() == 0) {
        float min_output_value;
        float max_output_value;
        MklQuantizationRangeForMultiplication<Tinput, qint8, qint32>(
            min_input, max_input, min_filter.flat<float>()(0),
            max_filter.flat<float>()(0), &min_output_value, &max_output_value);
        AllocateOutputSetMklShape(context, 1, &output_min, {},
                                  output_min_mkl_shape);
        AllocateOutputSetMklShape(context, 2, &output_max, {},
                                  output_max_mkl_shape);
        output_min->flat<float>()(0) = min_output_value;
        output_max->flat<float>()(0) = max_output_value;
      } else {
        size_t depth = min_filter.NumElements();
        AllocateOutputSetMklShape(context, 1, &output_min,
                                  {static_cast<ptrdiff_t>(depth)},
                                  output_min_mkl_shape);
        AllocateOutputSetMklShape(context, 2, &output_max,
                                  {static_cast<ptrdiff_t>(depth)},
                                  output_max_mkl_shape);
        MklQuantizationRangeForMultiplication<Tinput, qint8, qint32>(
            min_input, max_input, min_filter, max_filter, &output_min,
            &output_max);
      }
    }
  }

 protected:
  void ExtendConvFwdParams(OpKernelContext* context,
                           MklConvFwdParams& params) override {
    MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
              bias_enabled, false, is_depthwise,
              false>::ExtendConvFwdParams(context, params);

    // When the output type is quint8 (or qint8), the output data is
    // requantized. A post_op "output_scale" is added to do the conversion.
    if (std::is_same<Toutput, quint8>::value ||
        std::is_same<Toutput, qint8>::value) {
      int bias_index_offset;
      bias_index_offset = bias_enabled ? 1 : 0;

      const float min_input =
          context->input(2 + bias_index_offset).flat<float>()(0);
      const float max_input =
          context->input(3 + bias_index_offset).flat<float>()(0);
      const Tensor& min_filter_vector = context->input(4 + bias_index_offset);
      const Tensor& max_filter_vector = context->input(5 + bias_index_offset);

      // min_freezed_output and max_freezed_output are the actual range
      // for the output.
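      // The requantization scale applied per output channel i is
      //   scale[i] = Q_out * range(input) * range(filter[i]) /
      //              (Q_in * 127 * range(output)),
      // where Q_out is 255 for quint8 output or 127 for qint8 output, Q_in is
      // 255 for quint8 input or 127 for qint8 input, and
      // range(x) = max(|min_x|, |max_x|). For illustration only (assumed
      // values): a quint8 input range of 6.0, a filter range of 2.0, and a
      // frozen output range of 12.0 give 255 * 6 * 2 / (255 * 127 * 12)
      // ~= 0.0079.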
      const float min_freezed_output =
          context->input(6 + bias_index_offset).flat<float>()(0);
      const float max_freezed_output =
          context->input(7 + bias_index_offset).flat<float>()(0);

      float int_output_limit =
          std::is_same<Toutput, quint8>::value ? 255.0f : 127.0f;
      size_t depth = min_filter_vector.NumElements();
      const float* min_filter = min_filter_vector.flat<float>().data();
      const float* max_filter = max_filter_vector.flat<float>().data();
      std::vector<float> scales(depth);
      float float_input_range =
          std::max(std::abs(min_input), std::abs(max_input));
      float float_output_range =
          std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
      const float int_const_scale_limit =
          (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
      for (size_t i = 0; i < depth; ++i) {
        // For simplicity and symmetry, we set the filter range to the outer
        // bounds of min_filter and max_filter.
        float float_filter_range =
            std::max(std::abs(min_filter[i]), std::abs(max_filter[i]));
        // To understand the scaling, please see mkl_requantize_ops_test.
        scales[i] = int_output_limit * float_input_range * float_filter_range /
                    (int_const_scale_limit * float_output_range);
      }
      // We create a partial key here to use with primitive key caching, which
      // improves key-creation performance. Instead of the actual values we use
      // the pointers of min/max_filter_vector; this works because the filter
      // vector here is a constant.
      FactoryKeyCreator param_key;
      param_key.AddAsKey<float>(min_input);
      param_key.AddAsKey<float>(max_input);
      param_key.AddAsKey<float>(min_freezed_output);
      param_key.AddAsKey<float>(max_freezed_output);
      param_key.AddAsKey<const float*>(min_filter);
      param_key.AddAsKey<const float*>(max_filter);
      params.post_op_params.push_back({"output_scale", mkldnn::algorithm::undef,
                                       scales, param_key.GetKey()});
    }
  }

  Tbias* GetBiasHandle(OpKernelContext* context,
                       std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
                       const Tensor& bias_tensor) override {
    if (!bias_enabled) {
      return nullptr;
    }
    if (std::is_same<Tbias, qint32>::value) {
      return static_cast<Tbias*>(
          const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
    }
    int bias_index_offset;
    bias_index_offset = bias_enabled ? 1 : 0;

    const float min_input =
        context->input(2 + bias_index_offset).flat<float>()(0);
    const float max_input =
        context->input(3 + bias_index_offset).flat<float>()(0);
    const Tensor& min_filter_vector = context->input(4 + bias_index_offset);
    const Tensor& max_filter_vector = context->input(5 + bias_index_offset);
    const float* min_filter = min_filter_vector.flat<float>().data();
    const float* max_filter = max_filter_vector.flat<float>().data();

    const float int_const_scale_limit =
        (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
    // Re-scale the bias if either of the following two conditions is met:
    // 1. Bias is not const;
    // 2. Bias is const, but the bias cache is empty (first iteration).
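    // The float bias has to be brought to the same scale as the int32
    // convolution accumulator, i.e. per output channel i
    //   bias_scale[i] = Q_in * 127 / (range(input) * range(filter[i])),
    // with Q_in = 255 for quint8 inputs and 127 for qint8 inputs. For
    // illustration only (assumed values): a quint8 input range of 6.0 and a
    // filter range of 2.0 give a scale of 255 * 127 / (6 * 2) ~= 2698.75, so
    // a float bias of 0.1 becomes roughly 270 in qint32.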
    size_t depth = min_filter_vector.NumElements();
    bool scales_are_valid = (depth == scales_.size());
    scales_.resize(depth);
    for (size_t i = 0; i < depth; ++i) {
      float tmp_scale =
          int_const_scale_limit /
          (std::max(std::abs(max_input), std::abs(min_input)) *
           std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
      if (scales_are_valid && std::abs(tmp_scale - scales_[i]) > 1e-6) {
        scales_are_valid = false;
      }
      scales_[i] = tmp_scale;
    }
    if (!is_bias_const_ || IsBiasCacheEmpty(context) || !scales_are_valid) {
      mkldnn::primitive_attr bias_attr;
      if (depth == 1) {
        bias_attr.set_output_scales(0, scales_);
      } else {
        bias_attr.set_output_scales(1, scales_);
      }

      auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())},
                                  MklDnnType<Tbias>(), memory::format_tag::x);
      void* bias_buf = static_cast<void*>(
          const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
      if (!input_bias_) {
        input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf);
      } else {
        input_bias_->set_data_handle(bias_buf);
      }

      if (!scaled_bias_buf_)
        AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
                              conv_fwd_pd->bias_desc(), &scaled_bias_buf_);
      if (!scaled_bias_) {
        scaled_bias_ = new memory(bias_md, this->cpu_engine_, scaled_bias_buf_);
      } else {
        scaled_bias_->set_data_handle(scaled_bias_buf_);
      }
      auto reorder_desc =
          ReorderPd(this->cpu_engine_, input_bias_->get_desc(),
                    this->cpu_engine_, scaled_bias_->get_desc(), bias_attr);
      CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
                              this->cpu_engine_, context);

      Tbias* bias_data =
          reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
      if (is_bias_const_)
        CacheBias(context, conv_fwd_pd, bias_data, scaled_bias_);

      return bias_data;
    }
    return GetCachedBias(context);
  }

  bool is_bias_const_;
  PersistentTensor cached_bias_data_ptensor_ TF_GUARDED_BY(bias_cache_mu_);

  memory* input_bias_ = nullptr;
  memory* scaled_bias_ = nullptr;

  Tensor scaled_bias_tensor_;
  void* scaled_bias_buf_ = nullptr;

 private:
  std::vector<float> scales_;
  mutex bias_cache_mu_;
  // Allocate persistent tensors for the cached bias data and the
  // cached bias memory descriptor (data format).
  void AllocatePersistentTensor(OpKernelContext* context,
                                const ConvFwdPd& conv_prim_desc,
                                Tensor** bias_tensor) {
    DCHECK(bias_tensor);
    TensorShape bias_tf_shape;
    bias_tf_shape.AddDim(
        (conv_prim_desc.bias_desc().get_size() / sizeof(Tbias)));
    OP_REQUIRES_OK(context, context->allocate_persistent(
                                DataTypeToEnum<Tbias>::value, bias_tf_shape,
                                &cached_bias_data_ptensor_, bias_tensor));
  }

  // The TF_LOCKS_EXCLUDED annotation ensures that the lock (bias_cache_mu_)
  // is not already held when entering the function, since it is acquired
  // inside the function.
  inline bool IsBiasCacheEmpty(OpKernelContext* context)
      TF_LOCKS_EXCLUDED(bias_cache_mu_) {
    tf_shared_lock lock(bias_cache_mu_);
    return (cached_bias_data_ptensor_.NumElements() == 0);
  }

  // Cache the converted bias in a persistent tensor.
  // Only one thread can execute this method at any given time.
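  // The cached bias is written at most once: once the persistent tensor is
  // non-empty, subsequent calls with a constant bias and unchanged
  // quantization ranges return the cached copy via GetCachedBias() instead of
  // redoing the reorder above.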
  void CacheBias(OpKernelContext* context,
                 const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
                 Tbias* bias_data, const memory* scaled_bias)
      TF_LOCKS_EXCLUDED(bias_cache_mu_) {
    mutex_lock lock(bias_cache_mu_);

    // If the bias is already cached, there is nothing to do.
    if (cached_bias_data_ptensor_.NumElements() > 0) {
      return;
    }

    // Otherwise, cache the bias.
    Tensor* bias_tensor_ptr = nullptr;
    AllocatePersistentTensor(context, *conv_fwd_pd, &bias_tensor_ptr);
    void* cached_bias_data = const_cast<void*>(
        static_cast<const void*>(bias_tensor_ptr->flat<Tbias>().data()));
    size_t cached_bias_data_size = scaled_bias->get_desc().get_size();
    memcpy(cached_bias_data, bias_data, cached_bias_data_size);
  }

  Tbias* GetCachedBias(OpKernelContext* context)
      TF_LOCKS_EXCLUDED(bias_cache_mu_) {
    tf_shared_lock lock(bias_cache_mu_);
    const Tensor& cached_bias_data =
        *cached_bias_data_ptensor_.AccessTensor(context);

    return static_cast<Tbias*>(
        const_cast<Tbias*>(cached_bias_data.flat<Tbias>().data()));
  }
};

template <typename Device, typename Tinput, typename Tbias, typename Toutput,
          typename Ttemp_output, bool bias_enabled, bool is_depthwise>
class MklQuantizedConv2DReluOp
    : public MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
                                  bias_enabled, is_depthwise> {
 public:
  virtual ~MklQuantizedConv2DReluOp() {}

  explicit MklQuantizedConv2DReluOp(OpKernelConstruction* context)
      : MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
                             bias_enabled, is_depthwise>(context) {}

 protected:
  void ExtendConvFwdParams(OpKernelContext* context,
                           MklConvFwdParams& params) override {
    MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
                         bias_enabled,
                         is_depthwise>::ExtendConvFwdParams(context, params);

    params.post_op_params.push_back(
        {"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
  }
};

template <typename Device, typename Tinput, typename Tbias, typename Toutput,
          typename Ttemp_output, bool bias_enabled, bool is_depthwise>
class MklQuantizedConv2DSumReluOp
    : public MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
                                  bias_enabled, is_depthwise> {
 public:
  virtual ~MklQuantizedConv2DSumReluOp() {}

  explicit MklQuantizedConv2DSumReluOp(OpKernelConstruction* context)
      : MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
                             bias_enabled, is_depthwise>(context) {}

 protected:
  void ExtendConvFwdParams(OpKernelContext* context,
                           MklConvFwdParams& params) override {
    MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
                         bias_enabled,
                         is_depthwise>::ExtendConvFwdParams(context, params);
    // Calculate the scale (beta in mkldnn API terms) for the sum post-op.
    if (std::is_same<Toutput, quint8>::value) {
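      // The summand is accumulated into the output with a scale factor of
      // quant_step(summand) / quant_step(output). For a quint8 summand both
      // steps are range / 255, so the 255 factors cancel and the scale reduces
      // to range(summand) / range(output). For a qint8 summand the step is
      // range / 127, which yields 255 * range(summand) / (127 * range(output)).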
      int summand_idx = context->num_inputs() / 2 - 1 - 2;
      DataType summand_type = this->input_type(summand_idx);
      bool summand_condition =
          (summand_type == DT_QINT8) || (summand_type == DT_QUINT8);
      CHECK((summand_condition));
      int bias_index_offset = bias_enabled ? 1 : 0;
      const float min_freezed_output =
          context->input(6 + bias_index_offset).flat<float>()(0);
      const float max_freezed_output =
          context->input(7 + bias_index_offset).flat<float>()(0);
      const float min_freezed_summand =
          context->input(9 + bias_index_offset).flat<float>()(0);
      const float max_freezed_summand =
          context->input(10 + bias_index_offset).flat<float>()(0);

      float scale_output =
          std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
      float scale_summand = std::max(std::abs(min_freezed_summand),
                                     std::abs(max_freezed_summand));
      // If the summand type is DT_QUINT8, like the output, the 255.0f scaling
      // factors cancel each other and are therefore omitted. Otherwise the
      // summand is DT_QINT8 and is scaled accordingly.
      if (summand_type == DT_QUINT8) {
        params.post_op_params.push_back({"sum",
                                         mkldnn::algorithm::undef,
                                         {scale_summand / scale_output},
                                         ""});
      } else {
        params.post_op_params.push_back(
            {"sum",
             mkldnn::algorithm::undef,
             {255.0f * scale_summand / (scale_output * 127.0f)},
             ""});
      }
    } else {
      params.post_op_params.push_back(
          {"sum", mkldnn::algorithm::undef, {1.0}, ""});
    }
    params.post_op_params.push_back(
        {"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
  }

  void AllocateOutputTensor(OpKernelContext* context,
                            const ConvFwdPd& conv_prim_desc,
                            const memory::dims& output_dims_mkl_order,
                            MklTensorFormat output_tf_format,
                            MklDnnShape* output_mkl_shape,
                            Tensor** output_tensor) override {
    int summand_idx = context->num_inputs() / 2 - 1;
    if (std::is_same<Toutput, quint8>::value) {
      summand_idx -= 2;
      DataType summand_type = this->input_type(summand_idx);
      bool summand_condition =
          (summand_type == DT_QINT8) || (summand_type == DT_QUINT8);
      CHECK((summand_condition));
      Tensor& summand = const_cast<Tensor&>(MklGetInput(context, summand_idx));
      MklDnnShape summand_mkl_shape;
      GetMklShape(context, summand_idx, &summand_mkl_shape);
      auto dst_md = summand_mkl_shape.GetMklLayout();

      // TODO(intel-tf): Handle both non-MKL and MKL tensors.
      if (summand_type == DT_QINT8) {
        OP_REQUIRES_OK(
            context, summand.BitcastFrom(summand, DT_QUINT8, summand.shape()));
        dst_md.data.data_type =
            static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
        summand_mkl_shape.SetMklLayout(&dst_md);
        summand_mkl_shape.SetElemType(MklDnnType<Toutput>());
      }
      // TODO(intel-tf): Support cases when the summand cannot be forwarded.
      OP_REQUIRES(
          context,
          ForwardMklTensorInToOutWithMklShape(
              context, summand_idx, 0, output_tensor, summand_mkl_shape, false),
          errors::InvalidArgument(
              "Summand cannot be forwarded in the current fusion."));
      return;
    }
    MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
              bias_enabled, false, false,
              false>::AllocateOutputTensor(context, conv_prim_desc,
                                           output_dims_mkl_order,
                                           output_tf_format, output_mkl_shape,
                                           output_tensor);
    const Tensor& summand = MklGetInput(context, summand_idx);
    if (summand.dtype() != DT_FLOAT)
      TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
                         "Current fusion requires summand to be float"));
    MklDnnShape summand_mkl_shape;
    GetMklShape(context, summand_idx, &summand_mkl_shape);
    // We need to compute the scale for the summand.
    int bias_index_offset = bias_enabled ? 1 : 0;
    const float min_input =
        context->input(2 + bias_index_offset).flat<float>()(0);
    const float max_input =
        context->input(3 + bias_index_offset).flat<float>()(0);
    const Tensor& min_filter_vector = context->input(4 + bias_index_offset);
    const Tensor& max_filter_vector = context->input(5 + bias_index_offset);
    const float* min_filter = min_filter_vector.flat<float>().data();
    const float* max_filter = max_filter_vector.flat<float>().data();

    const float int_const_scale_limit =
        (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
    size_t depth = min_filter_vector.NumElements();
    std::vector<float> scales(depth);
    for (size_t i = 0; i < depth; ++i) {
      // TODO(nammbash): Scale factors for quint8 inputs and qint8 weights are
      // computed in the regular way here. A cleaner design that handles all
      // such mappings in one function, and that also supports other quantized
      // type mappings, should be implemented in the future.
      scales[i] = int_const_scale_limit /
                  (std::max(std::abs(max_input), std::abs(min_input)) *
                   std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
    }
    mkldnn::primitive_attr reorder_attr;
    if (depth == 1) {
      reorder_attr.set_output_scales(0, scales);
    } else {
      reorder_attr.set_output_scales(2, scales);
    }
    auto summand_md =
        summand_mkl_shape.IsMklTensor()
            ? summand_mkl_shape.GetMklLayout()
            : memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
                           memory::format_tag::nhwc);
    void* summand_buf =
        static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
    void* dst_buf =
        static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
    summand_.reset(new memory(summand_md, this->cpu_engine_, summand_buf));
    dst_.reset(
        new memory(conv_prim_desc.dst_desc(), this->cpu_engine_, dst_buf));
    auto reorder_desc =
        ReorderPd(this->cpu_engine_, summand_md, this->cpu_engine_,
                  conv_prim_desc.dst_desc(), reorder_attr);
    CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
                            context);
  }

  std::shared_ptr<mkldnn::memory> summand_;
  std::shared_ptr<mkldnn::memory> dst_;
};

// INT8 kernel registration
// Register NoOp kernel for QuantizedConv2D for qint8 filter
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2D")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint8>("out_type"),
                        NoOp);

// Register NoOp kernel for QuantizedConv2DPerChannel.
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DPerChannel")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);
// Register a templatized implementation of MklQuantizedConv2DPerChannel.
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DPerChannel")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32,
                                             qint32, false, false>);

// Register a templatized implementation of MklQuantizedConv2D.
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2D")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32,
                                             qint32, false, false>);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2D")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DOp<CPUDevice, qint8, float, qint32,
                                             qint32, false, false>);
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint8>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DOp<CPUDevice, quint8, qint32, qint8,
                                             qint8, false, false>);

// Register NoOp kernel for QuantizedConv2DWithBias to get a python interface.
// This kernel will be replaced by an MKL kernel during graph
// optimization pass.
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBias")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint8>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBias")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint8>("out_type"),
                        NoOp);
// Register a templatized implementation of MklQuantizedConv2DWithBias.
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBias")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32,
                                             qint32, true, false>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("Tbias")
        .TypeConstraint<qint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DOp<CPUDevice, quint8, qint32, qint8, qint8, true, false>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<float>("Tbias")
        .TypeConstraint<qint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DOp<CPUDevice, quint8, float, qint8, qint8, true, false>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBias")
        .Device(DEVICE_CPU)
        .TypeConstraint<qint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DOp<CPUDevice, qint8, float, qint32, qint32, true, false>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<qint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("Tbias")
        .TypeConstraint<qint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DOp<CPUDevice, qint8, qint32, qint8, qint8, true, false>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<qint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<float>("Tbias")
        .TypeConstraint<qint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DOp<CPUDevice, qint8, float, qint8, qint8, true, false>);

// Register NoOp kernel for QuantizedConv2DAndRelu to get a python interface.
// This kernel will be replaced by an MKL kernel during graph-optimization pass.
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndReluAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<quint8>("out_type"),
                        NoOp);

// Register a templatized implementation of MklQuantizedConv2DAndRelu.
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DAndRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DReluOp<CPUDevice, quint8, float,
                                                 qint32, qint32, false, false>);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DAndReluAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<quint8>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32,
                                                 quint8, quint8, false, false>);

// Register NoOp kernel for QuantizedConv2DWithBiasAndRelu to get a python
// interface.
// This kernel will be replaced by an MKL kernel during graph-optimization pass.
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

// Register NoOp kernel for QuantizedConv2DWithBiasAndReluAndRequantize
// to get a python interface.
// This kernel will be replaced by an MKL kernel during graph-optimization pass.
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndReluAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<quint8>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndReluAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<quint8>("out_type"),
                        NoOp);
// Register a templatized implementation of MklQuantizedConv2DWithBiasAndRelu.
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DReluOp<CPUDevice, quint8, float,
                                                 qint32, qint32, true, false>);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DReluOp<CPUDevice, qint8, float,
                                                 qint32, qint32, true, false>);
// Register a templatized implementation of
// MklQuantizedConv2DWithBiasAndReluAndRequantize.
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<float>("Tbias")
                            .TypeConstraint<quint8>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DReluOp<CPUDevice, quint8, float,
                                                 quint8, quint8, true, false>);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("Tbias")
                            .TypeConstraint<quint8>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32,
                                                 quint8, quint8, true, false>);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<float>("Tbias")
                            .TypeConstraint<quint8>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DReluOp<CPUDevice, qint8, float,
                                                 quint8, quint8, true, false>);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("Tbias")
                            .TypeConstraint<quint8>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DReluOp<CPUDevice, qint8, qint32,
                                                 quint8, quint8, true, false>);

// Register NoOp kernel for QuantizedConv2DWithBiasSumAndRelu to get a python
// interface.
// This kernel will be replaced by an MKL kernel during graph-optimization pass.
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasSumAndRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasSumAndReluAndRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<quint8>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(
    Name("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<quint8>("out_type"),
    NoOp);

// Register a templatized implementation of
// MklQuantizedConv2DWithBiasSumAndRelu.
REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasSumAndRelu")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DSumReluOp<CPUDevice, quint8, float, qint32, qint32, true,
                                false>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("Tbias")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DSumReluOp<CPUDevice, quint8, qint32, quint8, quint8, true,
                                false>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("Tbias")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DSumReluOp<CPUDevice, quint8, qint32, quint8, qint8, true,
                                false>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<float>("Tbias")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DSumReluOp<CPUDevice, quint8, float, quint8, quint8, true,
                                false>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<float>("Tbias")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DSumReluOp<CPUDevice, quint8, float, quint8, qint8, true,
                                false>);

// Register NoOp kernels for non-fused and fused versions of
// QuantizedDepthwiseConv2D to get a Python interface. These kernels will be
// replaced by MKL kernels during the graph-optimization pass.
REGISTER_KERNEL_BUILDER(Name("QuantizedDepthwiseConv2D")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedDepthwiseConv2DWithBias")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(Name("QuantizedDepthwiseConv2DWithBiasAndRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type"),
                        NoOp);

REGISTER_KERNEL_BUILDER(
    Name("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<quint8>("out_type"),
    NoOp);

REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<bfloat16>("T"),
                        NoOp);

#define REGISTER_NO_OP_CPU_2D_DEPTHWISE(T)                    \
  REGISTER_KERNEL_BUILDER(Name("_FusedDepthwiseConv2dNative") \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<T>("T"),        \
                          NoOp);

TF_CALL_float(REGISTER_NO_OP_CPU_2D_DEPTHWISE);
TF_CALL_bfloat16(REGISTER_NO_OP_CPU_2D_DEPTHWISE);

// Register templatized MKL kernels for non-fused and fused versions of
// QuantizedDepthwiseConv2D.
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedDepthwiseConv2D")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32,
                                             qint32, false, true>);

REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedDepthwiseConv2DWithBias")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32, qint32, true, true>);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedDepthwiseConv2DWithBiasAndRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint8>("Tfilter")
                            .TypeConstraint<qint32>("out_type")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklQuantizedConv2DReluOp<CPUDevice, quint8, float,
                                                 qint32, qint32, true, true>);

// Tbias -> float
REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<float>("Tbias")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DReluOp<CPUDevice, quint8, float, quint8, quint8, true,
                             true>);

// Tbias -> qint32
REGISTER_KERNEL_BUILDER(
    Name("_MklQuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("Tbias")
        .TypeConstraint<quint8>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
    MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32, quint8, quint8, true,
                             true>);

// Register 2D operations
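// In the MklConvOp registrations below, the five identical type arguments are
// the input/filter/bias/output/temp-output element types, int32 or int64 is
// Tpadding, and the four trailing booleans select, in order, bias_enabled,
// pad_enabled, is_depthwise, and native_format.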
#define REGISTER_MKL_CPU_2D(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_MklConv2D")                                                       \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<T>("T")                                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_MklConv2DWithBias")                                               \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<T>("T")                                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
      MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, false>);  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("__MklDummyConv2DWithBias")                                         \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<T>("T")                                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
      MklDummyOp<CPUDevice, T>);                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_MklPadWithConv2D")                                                \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<T>("T")                                              \
          .TypeConstraint<int32>("Tpaddings")                                  \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, false>);  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_MklPadWithConv2D")                                                \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<T>("T")                                              \
          .TypeConstraint<int64>("Tpaddings")                                  \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
      MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, false>);  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("__MklDummyPadWithConv2D")                                          \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<T>("T")                                              \
          .TypeConstraint<int32>("Tpaddings")                                  \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
      MklDummyOp<CPUDevice, T>);                                               \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_MklNativeConv2D")                                                 \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<T>("T")                                              \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>);  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_MklNativeConv2DWithBias")                                         \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<T>("T")                                              \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
      MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, true>);   \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_MklNativePadWithConv2D")                                          \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<T>("T")                                              \
          .TypeConstraint<int32>("Tpaddings")                                  \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, true>);   \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_MklNativePadWithConv2D")                                          \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<T>("T")                                              \
          .TypeConstraint<int64>("Tpaddings")                                  \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
      MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, true>);

TF_CALL_float(REGISTER_MKL_CPU_2D);
TF_CALL_bfloat16(REGISTER_MKL_CPU_2D);

#define REGISTER_MKL_CPU_2D_DEPTHWISE(T)                                       \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_MklDepthwiseConv2dNative")                                        \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<T>("T")                                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, false>);  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_MklFusedDepthwiseConv2dNative")                                   \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<T>("T")                                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
      MklFusedDepthwiseConvOp<CPUDevice, T, T, T, T, T, int32, false, true,    \
                              true, false>);                                   \
  REGISTER_KERNEL_BUILDER(                                                     \
Name("_MklNativeFusedDepthwiseConv2dNative") \ 2383 .Device(DEVICE_CPU) \ 2384 .TypeConstraint<T>("T") \ 2385 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2386 MklFusedDepthwiseConvOp<CPUDevice, T, T, T, T, T, int32, false, true, \ 2387 true, true>); \ 2388 REGISTER_KERNEL_BUILDER( \ 2389 Name("_MklNativeDepthwiseConv2dNative") \ 2390 .Device(DEVICE_CPU) \ 2391 .TypeConstraint<T>("T") \ 2392 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2393 MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, true>); 2394 2395 TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE); 2396 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_DEPTHWISE); 2397 2398 // Note we are registering _MklFusedConv2D. 2399 // We check the fused_ops attributes to decide if bias is enabled or not. 2400 #define REGISTER_MKL_CPU_2D_FUSED(T) \ 2401 REGISTER_KERNEL_BUILDER( \ 2402 Name("_MklFusedConv2D") \ 2403 .Device(DEVICE_CPU) \ 2404 .TypeConstraint<T>("T") \ 2405 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2406 MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, false, false>); \ 2407 REGISTER_KERNEL_BUILDER( \ 2408 Name("_MklPadWithFusedConv2D") \ 2409 .Device(DEVICE_CPU) \ 2410 .TypeConstraint<int32>("Tpaddings") \ 2411 .TypeConstraint<T>("T") \ 2412 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2413 MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, true, false>); \ 2414 REGISTER_KERNEL_BUILDER( \ 2415 Name("_MklPadWithFusedConv2D") \ 2416 .Device(DEVICE_CPU) \ 2417 .TypeConstraint<T>("T") \ 2418 .TypeConstraint<int64>("Tpaddings") \ 2419 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2420 MklFusedConvOp<CPUDevice, T, T, T, T, T, int64, true, false>); \ 2421 REGISTER_KERNEL_BUILDER( \ 2422 Name("__MklDummyPadWithFusedConv2D") \ 2423 .Device(DEVICE_CPU) \ 2424 .TypeConstraint<T>("T") \ 2425 .TypeConstraint<int32>("Tpaddings") \ 2426 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2427 MklDummyOp<CPUDevice, T>); \ 2428 REGISTER_KERNEL_BUILDER( \ 2429 Name("_MklNativeFusedConv2D") \ 2430 .Device(DEVICE_CPU) \ 2431 .TypeConstraint<T>("T") \ 2432 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2433 MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, false, true>); \ 2434 REGISTER_KERNEL_BUILDER( \ 2435 Name("_MklNativePadWithFusedConv2D") \ 2436 .Device(DEVICE_CPU) \ 2437 .TypeConstraint<int32>("Tpaddings") \ 2438 .TypeConstraint<T>("T") \ 2439 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2440 MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, true, true>); \ 2441 REGISTER_KERNEL_BUILDER( \ 2442 Name("_MklNativePadWithFusedConv2D") \ 2443 .Device(DEVICE_CPU) \ 2444 .TypeConstraint<T>("T") \ 2445 .TypeConstraint<int64>("Tpaddings") \ 2446 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2447 MklFusedConvOp<CPUDevice, T, T, T, T, T, int64, true, true>); 2448 2449 TF_CALL_float(REGISTER_MKL_CPU_2D_FUSED); 2450 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_FUSED); 2451 2452 // Register 3D operations 2453 #define REGISTER_MKL_CPU_3D(T) \ 2454 REGISTER_KERNEL_BUILDER( \ 2455 Name("_MklConv3D") \ 2456 .Device(DEVICE_CPU) \ 2457 .TypeConstraint<T>("T") \ 2458 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2459 MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \ 2460 REGISTER_KERNEL_BUILDER( \ 2461 Name("_MklNativeConv3D") \ 2462 .Device(DEVICE_CPU) \ 2463 .TypeConstraint<T>("T") \ 2464 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2465 MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>); 2466 TF_CALL_float(REGISTER_MKL_CPU_3D); 2467 
TF_CALL_bfloat16(REGISTER_MKL_CPU_3D);

}  // namespace tensorflow
#endif  // INTEL_MKL